diff --git a/docs/source/tutorials/diabetes_130/readmission_prediction_detectron.ipynb b/docs/source/tutorials/diabetes_130/readmission_prediction_detectron.ipynb
index 7be2dd414..a877f6af0 100644
--- a/docs/source/tutorials/diabetes_130/readmission_prediction_detectron.ipynb
+++ b/docs/source/tutorials/diabetes_130/readmission_prediction_detectron.ipynb
@@ -4,9 +4,10 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "# Readmission Prediction\n",
+ "# Readmission Prediction Detectron Implementation\n",
"\n",
- "This notebook showcases readmission prediction on the [Diabetes 130-US Hospitals for Years 1999-2008](https://archive.ics.uci.edu/dataset/296/diabetes+130-us+hospitals+for+years+1999-2008) using CyclOps. The task is formulated as a binary classification task, where we predict the probability of early readmission of the patient within 30 days of discharge."
+ "This notebook showcases readmission prediction on the [Diabetes 130-US Hospitals for Years 1999-2008](https://archive.ics.uci.edu/dataset/296/diabetes+130-us+hospitals+for+years+1999-2008) using CyclOps. The task is formulated as a binary classification task, where we predict the probability of early readmission of the patient within 30 days of discharge. The model health is then evaluated on a\n",
+ "held-out test set using the [Detectron](https://github.com/rgklab/detectron) method."
]
},
{
@@ -20,107 +21,11 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"metadata": {
"tags": []
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Requirement already satisfied: pycyclops[xgboost] in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (0.2.9)\n",
- "Requirement already satisfied: Jinja2<4.0.0,>=3.1.3 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (3.1.4)\n",
- "Requirement already satisfied: array-api-compat==1.6 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (1.6)\n",
- "Requirement already satisfied: datasets<3.0.0,>=2.15.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (2.19.0)\n",
- "Requirement already satisfied: hydra-core<2.0.0,>=1.2.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (1.3.2)\n",
- "Requirement already satisfied: kaleido==0.2.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (0.2.1)\n",
- "Requirement already satisfied: matplotlib<4.0.0,>=3.8.3 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (3.8.3)\n",
- "Requirement already satisfied: numpy<2.0.0,>=1.24.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (1.24.4)\n",
- "Requirement already satisfied: pandas<3.0,>=2.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pandas[performance]<3.0,>=2.1->pycyclops[xgboost]) (2.1.4)\n",
- "Requirement already satisfied: pillow<11.0.0,>=10.0.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (10.3.0)\n",
- "Requirement already satisfied: plotly<6.0.0,>=5.7.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (5.18.0)\n",
- "Requirement already satisfied: psutil<6.0.0,>=5.9.4 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (5.9.7)\n",
- "Requirement already satisfied: pyarrow<15.0.0,>=14.0.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (14.0.2)\n",
- "Requirement already satisfied: pybtex<0.25.0,>=0.24.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (0.24.0)\n",
- "Requirement already satisfied: pydantic<2.0.0,>=1.10.11 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (1.10.13)\n",
- "Requirement already satisfied: scikit-learn<2.0.0,>=1.4.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (1.5.0)\n",
- "Requirement already satisfied: scipy<2.0.0,>=1.11.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (1.13.0rc1)\n",
- "Requirement already satisfied: scour<0.39.0,>=0.38.2 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (0.38.2)\n",
- "Requirement already satisfied: spdx-tools<0.9.0,>=0.8.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (0.8.2)\n",
- "Requirement already satisfied: xgboost<2.0.0,>=1.5.2 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (1.7.6)\n",
- "Requirement already satisfied: filelock in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (3.13.1)\n",
- "Requirement already satisfied: pyarrow-hotfix in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (0.6)\n",
- "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (0.3.7)\n",
- "Requirement already satisfied: requests>=2.19.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (2.32.0)\n",
- "Requirement already satisfied: tqdm>=4.62.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (4.66.4)\n",
- "Requirement already satisfied: xxhash in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (3.4.1)\n",
- "Requirement already satisfied: multiprocess in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (0.70.15)\n",
- "Requirement already satisfied: fsspec<=2024.3.1,>=2023.1.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from fsspec[http]<=2024.3.1,>=2023.1.0->datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (2023.10.0)\n",
- "Requirement already satisfied: aiohttp in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (3.9.5)\n",
- "Requirement already satisfied: huggingface-hub>=0.21.2 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (0.22.2)\n",
- "Requirement already satisfied: packaging in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (23.2)\n",
- "Requirement already satisfied: pyyaml>=5.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (6.0.1)\n",
- "Requirement already satisfied: omegaconf<2.4,>=2.2 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from hydra-core<2.0.0,>=1.2.0->pycyclops[xgboost]) (2.3.0)\n",
- "Requirement already satisfied: antlr4-python3-runtime==4.9.* in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from hydra-core<2.0.0,>=1.2.0->pycyclops[xgboost]) (4.9.3)\n",
- "Requirement already satisfied: MarkupSafe>=2.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from Jinja2<4.0.0,>=3.1.3->pycyclops[xgboost]) (2.1.3)\n",
- "Requirement already satisfied: contourpy>=1.0.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from matplotlib<4.0.0,>=3.8.3->pycyclops[xgboost]) (1.1.0)\n",
- "Requirement already satisfied: cycler>=0.10 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from matplotlib<4.0.0,>=3.8.3->pycyclops[xgboost]) (0.12.1)\n",
- "Requirement already satisfied: fonttools>=4.22.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from matplotlib<4.0.0,>=3.8.3->pycyclops[xgboost]) (4.47.0)\n",
- "Requirement already satisfied: kiwisolver>=1.3.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from matplotlib<4.0.0,>=3.8.3->pycyclops[xgboost]) (1.4.5)\n",
- "Requirement already satisfied: pyparsing>=2.3.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from matplotlib<4.0.0,>=3.8.3->pycyclops[xgboost]) (3.1.1)\n",
- "Requirement already satisfied: python-dateutil>=2.7 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from matplotlib<4.0.0,>=3.8.3->pycyclops[xgboost]) (2.8.2)\n",
- "Requirement already satisfied: importlib-resources>=3.2.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from matplotlib<4.0.0,>=3.8.3->pycyclops[xgboost]) (6.1.1)\n",
- "Requirement already satisfied: pytz>=2020.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pandas<3.0,>=2.1->pandas[performance]<3.0,>=2.1->pycyclops[xgboost]) (2023.3.post1)\n",
- "Requirement already satisfied: tzdata>=2022.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pandas<3.0,>=2.1->pandas[performance]<3.0,>=2.1->pycyclops[xgboost]) (2023.3)\n",
- "Requirement already satisfied: bottleneck>=1.3.4 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pandas[performance]<3.0,>=2.1->pycyclops[xgboost]) (1.3.8)\n",
- "Requirement already satisfied: numba>=0.55.2 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pandas[performance]<3.0,>=2.1->pycyclops[xgboost]) (0.57.1)\n",
- "Requirement already satisfied: numexpr>=2.8.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pandas[performance]<3.0,>=2.1->pycyclops[xgboost]) (2.10.0)\n",
- "Requirement already satisfied: tenacity>=6.2.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from plotly<6.0.0,>=5.7.0->pycyclops[xgboost]) (8.2.3)\n",
- "Requirement already satisfied: latexcodec>=1.0.4 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pybtex<0.25.0,>=0.24.0->pycyclops[xgboost]) (2.0.1)\n",
- "Requirement already satisfied: six in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pybtex<0.25.0,>=0.24.0->pycyclops[xgboost]) (1.16.0)\n",
- "Requirement already satisfied: typing-extensions>=4.2.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pydantic<2.0.0,>=1.10.11->pycyclops[xgboost]) (4.9.0)\n",
- "Requirement already satisfied: joblib>=1.2.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from scikit-learn<2.0.0,>=1.4.0->pycyclops[xgboost]) (1.3.2)\n",
- "Requirement already satisfied: threadpoolctl>=3.1.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from scikit-learn<2.0.0,>=1.4.0->pycyclops[xgboost]) (3.2.0)\n",
- "Requirement already satisfied: click in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from spdx-tools<0.9.0,>=0.8.1->pycyclops[xgboost]) (8.1.7)\n",
- "Requirement already satisfied: xmltodict in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from spdx-tools<0.9.0,>=0.8.1->pycyclops[xgboost]) (0.13.0)\n",
- "Requirement already satisfied: rdflib in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from spdx-tools<0.9.0,>=0.8.1->pycyclops[xgboost]) (7.0.0)\n",
- "Requirement already satisfied: beartype in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from spdx-tools<0.9.0,>=0.8.1->pycyclops[xgboost]) (0.16.4)\n",
- "Requirement already satisfied: uritools in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from spdx-tools<0.9.0,>=0.8.1->pycyclops[xgboost]) (4.0.2)\n",
- "Requirement already satisfied: license-expression in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from spdx-tools<0.9.0,>=0.8.1->pycyclops[xgboost]) (30.2.0)\n",
- "Requirement already satisfied: ply in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from spdx-tools<0.9.0,>=0.8.1->pycyclops[xgboost]) (3.11)\n",
- "Requirement already satisfied: semantic-version in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from spdx-tools<0.9.0,>=0.8.1->pycyclops[xgboost]) (2.10.0)\n",
- "Requirement already satisfied: aiosignal>=1.1.2 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from aiohttp->datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (1.3.1)\n",
- "Requirement already satisfied: attrs>=17.3.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from aiohttp->datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (23.1.0)\n",
- "Requirement already satisfied: frozenlist>=1.1.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from aiohttp->datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (1.4.1)\n",
- "Requirement already satisfied: multidict<7.0,>=4.5 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from aiohttp->datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (6.0.4)\n",
- "Requirement already satisfied: yarl<2.0,>=1.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from aiohttp->datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (1.9.4)\n",
- "Requirement already satisfied: async-timeout<5.0,>=4.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from aiohttp->datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (4.0.3)\n",
- "Requirement already satisfied: zipp>=3.1.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from importlib-resources>=3.2.0->matplotlib<4.0.0,>=3.8.3->pycyclops[xgboost]) (3.17.0)\n",
- "Requirement already satisfied: llvmlite<0.41,>=0.40.0dev0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from numba>=0.55.2->pandas[performance]<3.0,>=2.1->pycyclops[xgboost]) (0.40.1)\n",
- "Requirement already satisfied: charset-normalizer<4,>=2 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from requests>=2.19.0->datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (3.3.2)\n",
- "Requirement already satisfied: idna<4,>=2.5 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from requests>=2.19.0->datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (3.7)\n",
- "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from requests>=2.19.0->datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (2.2.2)\n",
- "Requirement already satisfied: certifi>=2017.4.17 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from requests>=2.19.0->datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (2023.11.17)\n",
- "Requirement already satisfied: boolean.py>=4.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from license-expression->spdx-tools<0.9.0,>=0.8.1->pycyclops[xgboost]) (4.0)\n",
- "Requirement already satisfied: isodate<0.7.0,>=0.6.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from rdflib->spdx-tools<0.9.0,>=0.8.1->pycyclops[xgboost]) (0.6.1)\n",
- "\n",
- "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.1.2\u001b[0m\n",
- "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
- "Requirement already satisfied: ucimlrepo in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (0.0.7)\n",
- "Requirement already satisfied: pandas>=1.0.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from ucimlrepo) (2.1.4)\n",
- "Requirement already satisfied: certifi>=2020.12.5 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from ucimlrepo) (2023.11.17)\n",
- "Requirement already satisfied: numpy<2,>=1.22.4 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pandas>=1.0.0->ucimlrepo) (1.24.4)\n",
- "Requirement already satisfied: python-dateutil>=2.8.2 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pandas>=1.0.0->ucimlrepo) (2.8.2)\n",
- "Requirement already satisfied: pytz>=2020.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pandas>=1.0.0->ucimlrepo) (2023.3.post1)\n",
- "Requirement already satisfied: tzdata>=2022.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pandas>=1.0.0->ucimlrepo) (2023.3)\n",
- "Requirement already satisfied: six>=1.5 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from python-dateutil>=2.8.2->pandas>=1.0.0->ucimlrepo) (1.16.0)\n",
- "\n",
- "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.1.2\u001b[0m\n",
- "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"!pip install pycyclops[xgboost]\n",
"!pip install ucimlrepo"
@@ -135,33 +40,21 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"metadata": {
"tags": []
},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
- " from .autonotebook import tqdm as notebook_tqdm\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"\"\"\"Readmission prediction.\"\"\"\n",
"\n",
"# ruff: noqa: E402\n",
"\n",
- "import copy\n",
- "import inspect\n",
- "from datetime import date\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import plotly.express as px\n",
- "from datasets import Dataset\n",
+ "from datasets import Dataset, DatasetDict\n",
"from datasets.features import ClassLabel\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
@@ -170,23 +63,9 @@
"from ucimlrepo import fetch_ucirepo\n",
"\n",
"from cyclops.data.df.feature import TabularFeatures\n",
- "from cyclops.data.slicer import SliceSpec\n",
- "from cyclops.evaluate.fairness import FairnessConfig # noqa: E402\n",
- "from cyclops.evaluate.metrics import create_metric\n",
- "from cyclops.evaluate.metrics.experimental.functional import (\n",
- " binary_npv,\n",
- " binary_ppv,\n",
- " binary_roc,\n",
- ")\n",
- "from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict\n",
"from cyclops.models.catalog import create_model\n",
- "from cyclops.report import ModelCardReport\n",
- "from cyclops.report.plot.classification import ClassificationPlotter\n",
- "from cyclops.report.utils import flatten_results_dict\n",
- "from cyclops.tasks import BinaryTabularClassificationTask\n",
- "\n",
"from cyclops.monitor.tester import Detectron\n",
- "from datasets import DatasetDict"
+ "from cyclops.tasks import BinaryTabularClassificationTask"
]
},
{
@@ -198,7 +77,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": null,
"metadata": {
"tags": []
},
@@ -219,22 +98,11 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"metadata": {
"tags": []
},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages/ucimlrepo/fetch.py:97: DtypeWarning:\n",
- "\n",
- "Columns (10) have mixed types. Specify dtype option on import or set low_memory=False.\n",
- "\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"diabetes_130_data = fetch_ucirepo(id=296)\n",
"features = diabetes_130_data[\"data\"][\"features\"]\n",
@@ -245,63 +113,18 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"metadata": {
"tags": []
},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "{'uci_id': 296,\n",
- " 'name': 'Diabetes 130-US Hospitals for Years 1999-2008',\n",
- " 'repository_url': 'https://archive.ics.uci.edu/dataset/296/diabetes+130-us+hospitals+for+years+1999-2008',\n",
- " 'data_url': 'https://archive.ics.uci.edu/static/public/296/data.csv',\n",
- " 'abstract': 'The dataset represents ten years (1999-2008) of clinical care at 130 US hospitals and integrated delivery networks. Each row concerns hospital records of patients diagnosed with diabetes, who underwent laboratory, medications, and stayed up to 14 days. The goal is to determine the early readmission of the patient within 30 days of discharge.\\nThe problem is important for the following reasons. Despite high-quality evidence showing improved clinical outcomes for diabetic patients who receive various preventive and therapeutic interventions, many patients do not receive them. This can be partially attributed to arbitrary diabetes management in hospital environments, which fail to attend to glycemic control. Failure to provide proper diabetes care not only increases the managing costs for the hospitals (as the patients are readmitted) but also impacts the morbidity and mortality of the patients, who may face complications associated with diabetes.\\n',\n",
- " 'area': 'Health and Medicine',\n",
- " 'tasks': ['Classification', 'Clustering'],\n",
- " 'characteristics': ['Multivariate'],\n",
- " 'num_instances': 101766,\n",
- " 'num_features': 47,\n",
- " 'feature_types': ['Categorical', 'Integer'],\n",
- " 'demographics': ['Race', 'Gender', 'Age'],\n",
- " 'target_col': ['readmitted'],\n",
- " 'index_col': ['encounter_id', 'patient_nbr'],\n",
- " 'has_missing_values': 'yes',\n",
- " 'missing_values_symbol': 'NaN',\n",
- " 'year_of_dataset_creation': 2014,\n",
- " 'last_updated': 'Mon Feb 26 2024',\n",
- " 'dataset_doi': '10.24432/C5230J',\n",
- " 'creators': ['John Clore', 'Krzysztof Cios', 'Jon DeShazo', 'Beata Strack'],\n",
- " 'intro_paper': {'title': 'Impact of HbA1c Measurement on Hospital Readmission Rates: Analysis of 70,000 Clinical Database Patient Record',\n",
- " 'authors': 'Beata Strack, Jonathan DeShazo, Chris Gennings, Juan Olmo, Sebastian Ventura, Krzysztof Cios, John Clore',\n",
- " 'published_in': 'BioMed Research International, vol. 2014',\n",
- " 'year': 2014,\n",
- " 'url': 'https://www.hindawi.com/journals/bmri/2014/781670/',\n",
- " 'doi': None},\n",
- " 'additional_info': {'summary': 'The dataset represents ten years (1999-2008) of clinical care at 130 US hospitals and integrated delivery networks. It includes over 50 features representing patient and hospital outcomes. Information was extracted from the database for encounters that satisfied the following criteria.\\n(1)\\tIt is an inpatient encounter (a hospital admission).\\n(2)\\tIt is a diabetic encounter, that is, one during which any kind of diabetes was entered into the system as a diagnosis.\\n(3)\\tThe length of stay was at least 1 day and at most 14 days.\\n(4)\\tLaboratory tests were performed during the encounter.\\n(5)\\tMedications were administered during the encounter.\\n\\nThe data contains such attributes as patient number, race, gender, age, admission type, time in hospital, medical specialty of admitting physician, number of lab tests performed, HbA1c test result, diagnosis, number of medications, diabetic medications, number of outpatient, inpatient, and emergency visits in the year before the hospitalization, etc.',\n",
- " 'purpose': None,\n",
- " 'funded_by': None,\n",
- " 'instances_represent': 'The instances represent hospitalized patient records diagnosed with diabetes.',\n",
- " 'recommended_data_splits': 'No recommendation. The standard train-test split could be used. Can use three-way holdout split (i.e., train-validation-test) when doing model selection.',\n",
- " 'sensitive_data': 'Yes. The dataset contains information about the age, gender, and race of the patients.',\n",
- " 'preprocessing_description': None,\n",
- " 'variable_info': 'Detailed description of all the atrributes is provided in Table 1 Beata Strack, Jonathan P. DeShazo, Chris Gennings, Juan L. Olmo, Sebastian Ventura, Krzysztof J. Cios, and John N. Clore, “Impact of HbA1c Measurement on Hospital Readmission Rates: Analysis of 70,000 Clinical Database Patient Records,” BioMed Research International, vol. 2014, Article ID 781670, 11 pages, 2014.\\n\\nhttp://www.hindawi.com/journals/bmri/2014/781670/',\n",
- " 'citation': 'Please cite:\\nBeata Strack, Jonathan P. DeShazo, Chris Gennings, Juan L. Olmo, Sebastian Ventura, Krzysztof J. Cios, and John N. Clore, “Impact of HbA1c Measurement on Hospital Readmission Rates: Analysis of 70,000 Clinical Database Patient Records,” BioMed Research International, vol. 2014, Article ID 781670, 11 pages, 2014.'}}"
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"metadata"
]
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"metadata": {
"tags": []
},
@@ -343,7 +166,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -359,7 +182,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -369,7 +192,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"metadata": {
"tags": []
},
@@ -384,22 +207,11 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": null,
"metadata": {
"tags": []
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "7.960641014352381 outcome\n",
- "0 90409\n",
- "1 11357\n",
- "Name: count, dtype: int64\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"class_counts = df[\"outcome\"].value_counts()\n",
"class_ratio = class_counts[0] / class_counts[1]\n",
@@ -415,7 +227,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": null,
"metadata": {
"tags": []
},
@@ -437,19 +249,11 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": null,
"metadata": {
"tags": []
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{'A1Cresult': 'ordinal', 'age': 'ordinal', 'pioglitazone': 'ordinal', 'num_medications': 'numeric', 'metformin-rosiglitazone': 'binary', 'tolazamide': 'ordinal', 'glipizide': 'ordinal', 'number_inpatient': 'numeric', 'troglitazone': 'binary', 'acarbose': 'ordinal', 'glyburide-metformin': 'ordinal', 'acetohexamide': 'binary', 'chlorpropamide': 'ordinal', 'medical_specialty': 'string', 'max_glu_serum': 'ordinal', 'repaglinide': 'ordinal', 'rosiglitazone': 'ordinal', 'admission_type_id': 'ordinal', 'glimepiride': 'ordinal', 'gender': 'ordinal', 'glipizide-metformin': 'binary', 'num_lab_procedures': 'numeric', 'number_emergency': 'numeric', 'glimepiride-pioglitazone': 'binary', 'nateglinide': 'ordinal', 'discharge_disposition_id': 'numeric', 'payer_code': 'ordinal', 'num_procedures': 'ordinal', 'number_outpatient': 'numeric', 'diag_3': 'string', 'change': 'binary', 'diabetesMed': 'binary', 'miglitol': 'ordinal', 'race': 'ordinal', 'diag_1': 'string', 'outcome': 'binary', 'diag_2': 'string', 'glyburide': 'ordinal', 'metformin': 'ordinal', 'metformin-pioglitazone': 'binary', 'weight': 'ordinal', 'admission_source_id': 'ordinal', 'tolbutamide': 'binary', 'number_diagnoses': 'ordinal', 'insulin': 'ordinal', 'time_in_hospital': 'ordinal'}\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"tab_features = TabularFeatures(\n",
" data=df.reset_index(),\n",
@@ -471,7 +275,7 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": null,
"metadata": {
"tags": []
},
@@ -488,19 +292,11 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": null,
"metadata": {
"tags": []
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "['discharge_disposition_id', 'num_lab_procedures', 'num_medications', 'number_emergency', 'number_inpatient', 'number_outpatient']\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"numeric_features = sorted((tab_features.features_by_type(\"numeric\")))\n",
"numeric_indices = [\n",
@@ -511,19 +307,11 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": null,
"metadata": {
"tags": []
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "['acetohexamide', 'change', 'diabetesMed', 'glimepiride-pioglitazone', 'glipizide-metformin', 'metformin-pioglitazone', 'metformin-rosiglitazone', 'tolbutamide', 'troglitazone'] ['A1Cresult', 'acarbose', 'admission_source_id', 'admission_type_id', 'age', 'chlorpropamide', 'diag_1', 'diag_2', 'diag_3', 'gender', 'glimepiride', 'glipizide', 'glyburide', 'glyburide-metformin', 'insulin', 'max_glu_serum', 'medical_specialty', 'metformin', 'miglitol', 'nateglinide', 'num_procedures', 'number_diagnoses', 'payer_code', 'pioglitazone', 'race', 'repaglinide', 'rosiglitazone', 'time_in_hospital', 'tolazamide', 'weight']\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"binary_features = sorted(tab_features.features_by_type(\"binary\"))\n",
"binary_features.remove(\"outcome\")\n",
@@ -542,7 +330,7 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": null,
"metadata": {
"tags": []
},
@@ -572,42 +360,12 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": null,
"metadata": {
"tags": []
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Dataset({\n",
- " features: ['race', 'gender', 'age', 'weight', 'admission_type_id', 'discharge_disposition_id', 'admission_source_id', 'time_in_hospital', 'payer_code', 'medical_specialty', 'num_lab_procedures', 'num_procedures', 'num_medications', 'number_outpatient', 'number_emergency', 'number_inpatient', 'diag_1', 'diag_2', 'diag_3', 'number_diagnoses', 'max_glu_serum', 'A1Cresult', 'metformin', 'repaglinide', 'nateglinide', 'chlorpropamide', 'glimepiride', 'acetohexamide', 'glipizide', 'glyburide', 'tolbutamide', 'pioglitazone', 'rosiglitazone', 'acarbose', 'miglitol', 'troglitazone', 'tolazamide', 'insulin', 'glyburide-metformin', 'glipizide-metformin', 'glimepiride-pioglitazone', 'metformin-rosiglitazone', 'metformin-pioglitazone', 'change', 'diabetesMed', 'outcome'],\n",
- " num_rows: 101766\n",
- "})\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "# upsample the minority class\n",
- "from sklearn.utils import resample\n",
- "\n",
- "# df_majority = df[df.outcome == 0]\n",
- "\n",
- "# df_minority = df[df.outcome == 1]\n",
- "\n",
- "# df_minority_upsampled = resample(\n",
- "# df_minority,\n",
- "# replace=True,\n",
- "# n_samples=len(df_majority),\n",
- "# random_state=RANDOM_SEED,\n",
- "# )\n",
- "\n",
- "# df_upsampled = pd.concat([df_majority, df_minority_upsampled])\n",
- "# df_upsampled = df_upsampled.sample(frac=1, random_state=RANDOM_SEED)\n",
- "# print(df_upsampled.outcome.value_counts())\n",
- "# dataset = Dataset.from_pandas(df_upsampled)\n",
- "\n",
"dataset = Dataset.from_pandas(df)\n",
"dataset.cleanup_cache_files()\n",
"print(dataset)"
@@ -615,7 +373,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": null,
"metadata": {
"tags": []
},
@@ -640,7 +398,7 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": null,
"metadata": {
"tags": []
},
@@ -661,7 +419,7 @@
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": null,
"metadata": {
"tags": []
},
@@ -676,22 +434,11 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": null,
"metadata": {
"tags": []
},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "['xgb_classifier']"
- ]
- },
- "execution_count": 21,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"readmission_prediction_task.list_models()"
]
@@ -709,473 +456,11 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": null,
"metadata": {
"tags": []
},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "2024-07-16 09:18:47,811 \u001b[1;37mINFO\u001b[0m cyclops.models.wrappers.sk_model - No validation split was found.\n",
- "2024-07-16 09:20:34,989 \u001b[1;37mINFO\u001b[0m cyclops.models.wrappers.sk_model - Best scale_pos_weight: 7\n",
- "2024-07-16 09:20:34,990 \u001b[1;37mINFO\u001b[0m cyclops.models.wrappers.sk_model - Best reg_lambda: 0\n",
- "2024-07-16 09:20:34,990 \u001b[1;37mINFO\u001b[0m cyclops.models.wrappers.sk_model - Best n_estimators: 100\n",
- "2024-07-16 09:20:34,991 \u001b[1;37mINFO\u001b[0m cyclops.models.wrappers.sk_model - Best max_depth: 5\n",
- "2024-07-16 09:20:34,991 \u001b[1;37mINFO\u001b[0m cyclops.models.wrappers.sk_model - Best learning_rate: 0.1\n",
- "2024-07-16 09:20:34,991 \u001b[1;37mINFO\u001b[0m cyclops.models.wrappers.sk_model - Best gamma: 0\n",
- "2024-07-16 09:20:34,992 \u001b[1;37mINFO\u001b[0m cyclops.models.wrappers.sk_model - Best colsample_bytree: 1\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "
XGBClassifier(base_score=None, booster=None, callbacks=None,\n",
- " colsample_bylevel=None, colsample_bynode=None, colsample_bytree=1,\n",
- " early_stopping_rounds=None, enable_categorical=False,\n",
- " eval_metric='logloss', feature_types=None, gamma=0, gpu_id=None,\n",
- " grow_policy=None, importance_type=None,\n",
- " interaction_constraints=None, learning_rate=0.1, max_bin=None,\n",
- " max_cat_threshold=None, max_cat_to_onehot=None,\n",
- " max_delta_step=None, max_depth=5, max_leaves=None,\n",
- " min_child_weight=3, missing=nan, monotone_constraints=None,\n",
- " n_estimators=100, n_jobs=None, num_parallel_tree=None,\n",
- " predictor=None, random_state=123, ...) In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. XGBClassifieriFitted XGBClassifier(base_score=None, booster=None, callbacks=None,\n",
- " colsample_bylevel=None, colsample_bynode=None, colsample_bytree=1,\n",
- " early_stopping_rounds=None, enable_categorical=False,\n",
- " eval_metric='logloss', feature_types=None, gamma=0, gpu_id=None,\n",
- " grow_policy=None, importance_type=None,\n",
- " interaction_constraints=None, learning_rate=0.1, max_bin=None,\n",
- " max_cat_threshold=None, max_cat_to_onehot=None,\n",
- " max_delta_step=None, max_depth=5, max_leaves=None,\n",
- " min_child_weight=3, missing=nan, monotone_constraints=None,\n",
- " n_estimators=100, n_jobs=None, num_parallel_tree=None,\n",
- " predictor=None, random_state=123, ...) "
- ],
- "text/plain": [
- "XGBClassifier(base_score=None, booster=None, callbacks=None,\n",
- " colsample_bylevel=None, colsample_bynode=None, colsample_bytree=1,\n",
- " early_stopping_rounds=None, enable_categorical=False,\n",
- " eval_metric='logloss', feature_types=None, gamma=0, gpu_id=None,\n",
- " grow_policy=None, importance_type=None,\n",
- " interaction_constraints=None, learning_rate=0.1, max_bin=None,\n",
- " max_cat_threshold=None, max_cat_to_onehot=None,\n",
- " max_delta_step=None, max_depth=5, max_leaves=None,\n",
- " min_child_weight=3, missing=nan, monotone_constraints=None,\n",
- " n_estimators=100, n_jobs=None, num_parallel_tree=None,\n",
- " predictor=None, random_state=123, ...)"
- ]
- },
- "execution_count": 22,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"best_model_params = {\n",
" \"n_estimators\": [100, 250, 500],\n",
@@ -1189,8 +474,12 @@
"}\n",
"dataset[\"train\"] = dataset[\"train\"].train_test_split(train_size=0.8, seed=RANDOM_SEED)\n",
"\n",
+ "train_dataset = dataset[\"train\"]\n",
+ "val = train_dataset.pop(\"test\")\n",
+ "train_dataset[\"validation\"] = val\n",
+ "\n",
"readmission_prediction_task.train(\n",
- " dataset[\"train\"],\n",
+ " train_dataset,\n",
" model_name=model_name,\n",
" transforms=preprocessor,\n",
" best_model_params=best_model_params,\n",
@@ -1199,17 +488,9 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{'objective': 'binary:logistic', 'use_label_encoder': None, 'base_score': None, 'booster': None, 'callbacks': None, 'colsample_bylevel': None, 'colsample_bynode': None, 'colsample_bytree': 1, 'early_stopping_rounds': None, 'enable_categorical': False, 'eval_metric': 'logloss', 'feature_types': None, 'gamma': 0, 'gpu_id': None, 'grow_policy': None, 'importance_type': None, 'interaction_constraints': None, 'learning_rate': 0.1, 'max_bin': None, 'max_cat_threshold': None, 'max_cat_to_onehot': None, 'max_delta_step': None, 'max_depth': 5, 'max_leaves': None, 'min_child_weight': 3, 'missing': nan, 'monotone_constraints': None, 'n_estimators': 100, 'n_jobs': None, 'num_parallel_tree': None, 'predictor': None, 'random_state': 123, 'reg_alpha': None, 'reg_lambda': 0, 'sampling_method': None, 'scale_pos_weight': 7, 'subsample': None, 'tree_method': None, 'validate_parameters': None, 'verbosity': None, 'seed': 123}\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"model_params = readmission_prediction_task.list_models_params()[model_name]\n",
"print(model_params)"
@@ -1219,91 +500,26 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "## Prediction\n",
- "\n",
- "The prediction output can be either the whole Hugging Face dataset with the prediction columns added to it or the single column containing the predicted values."
+ "Initialize detectron model with pre-trained weights and training/validation data."
]
},
{
"cell_type": "code",
- "execution_count": 24,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "# y_pred = readmission_prediction_task.predict(\n",
- "# dataset[\"test\"],\n",
- "# model_name=model_name,\n",
- "# transforms=preprocessor,\n",
- "# proba=True,\n",
- "# only_predictions=True,\n",
- "# )\n",
- "# prediction_df = pd.DataFrame(\n",
- "# {\n",
- "# \"y_prob\": [y_pred_i[1] for y_pred_i in y_pred],\n",
- "# \"y_true\": dataset[\"test\"][\"outcome\"],\n",
- "# }\n",
- "# )"
- ]
- },
- {
- "cell_type": "markdown",
+ "execution_count": null,
"metadata": {},
- "source": [
- "## Evaluation\n",
- "\n",
- "Evaluation is done using various evaluation metrics that provide different perspectives on the model's predictive abilities i.e. standard performance metrics and fairness metrics.\n",
- "\n",
- "The standard performance metrics can be created using the `MetricDict` object."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 25,
- "metadata": {
- "tags": []
- },
"outputs": [],
"source": [
- "metric_names = [\n",
- " \"binary_accuracy\",\n",
- " \"binary_precision\",\n",
- " \"binary_recall\",\n",
- " \"binary_f1_score\",\n",
- " \"binary_auroc\",\n",
- " \"binary_average_precision\",\n",
- " \"binary_roc_curve\",\n",
- " \"binary_precision_recall_curve\",\n",
- "]\n",
- "metrics = [\n",
- " create_metric(metric_name, experimental=True) for metric_name in metric_names\n",
- "]\n",
- "metric_collection = MetricDict(metrics)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 26,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "specificity = create_metric(metric_name=\"binary_specificity\", experimental=True)\n",
- "sensitivity = create_metric(metric_name=\"binary_sensitivity\", experimental=True)\n",
- "\n",
- "fpr = -specificity + 1\n",
- "fnr = -sensitivity + 1\n",
- "\n",
- "ber = (fpr + fnr) / 2\n",
- "\n",
- "fairness_metric_collection = MetricDict(\n",
- " {\n",
- " \"Sensitivity\": sensitivity,\n",
- " \"Specificity\": specificity,\n",
- " \"BER\": ber,\n",
- " },\n",
+ "tester = Detectron(\n",
+ " X_s=dataset[\"train\"],\n",
+ " base_model=readmission_prediction_task.models[\"xgb_classifier\"],\n",
+ " feature_column=features_list,\n",
+ " transforms=preprocessor,\n",
+ " splits_mapping={\"train\": \"train\", \"test\": \"test\"},\n",
+ " sample_size=250,\n",
+ " num_runs=5,\n",
+ " ensemble_size=5,\n",
+ " task=\"binary\",\n",
+ " save_dir=\"detectron\",\n",
")"
]
},
@@ -1311,54 +527,34 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "The evaluate methods outputs the evaluation results and the Hugging Face dataset with the predictions added to it."
+ "Get model health using the training data and all the test data."
]
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "tester = Detectron(X_s=dataset[\"train\"],\n",
- " base_model=readmission_prediction_task.models['xgb_classifier'],\n",
- " feature_column=features_list,\n",
- " transforms=preprocessor,\n",
- " splits_mapping={\"train\": \"train\", \"test\": \"test\"},\n",
- " sample_size=250,\n",
- " num_runs=5,\n",
- " ensemble_size=5,\n",
- " task=\"binary\",\n",
- " save_dir=\"detectron\",\n",
- ")"
+ "results = tester.predict(\n",
+ " X_t=DatasetDict({\"train\": dataset[\"train\"][\"train\"], \"test\": dataset[\"test\"]})\n",
+ ")\n",
+ "print(results[\"model_health\"])"
]
},
{
- "cell_type": "code",
- "execution_count": 28,
+ "cell_type": "markdown",
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "0.916030534351145\n"
- ]
- }
- ],
"source": [
- "# get model health on all test data\n",
- "results = tester.predict(X_t = DatasetDict({\"train\": dataset[\"train\"][\"train\"], \"test\": dataset[\"test\"]}))\n",
- "print(results[\"model_health\"])"
+ "Split the test data into multiple bins and plot the model health for each bin."
]
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "# split test data into 20 bins\n",
"test_data = dataset[\"test\"]\n",
"test_data_list = []\n",
"\n",
@@ -1366,950 +562,30 @@
"\n",
"bins = np.array_split(indices, 20)\n",
"\n",
- "for bin in bins:\n",
- " test_data_list.append(test_data.select(bin))"
+ "for b in bins:\n",
+ " test_data_list.append(test_data.select(b))"
]
},
{
"cell_type": "code",
- "execution_count": 30,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "# get model health on all test data bins\n",
"model_health = []\n",
"for data in test_data_list:\n",
- " results = tester.predict(X_t=DatasetDict({\"train\": dataset[\"train\"][\"train\"], \"test\": data}))\n",
+ " results = tester.predict(\n",
+ " X_t=DatasetDict({\"train\": dataset[\"train\"][\"train\"], \"test\": data})\n",
+ " )\n",
" model_health.append(results[\"model_health\"])"
]
},
{
"cell_type": "code",
- "execution_count": 31,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "application/vnd.plotly.v1+json": {
- "config": {
- "plotlyServerURL": "https://plot.ly"
- },
- "data": [
- {
- "hovertemplate": "bin=%{x} model_health=%{y} ",
- "legendgroup": "",
- "line": {
- "color": "#636efa",
- "dash": "solid"
- },
- "marker": {
- "symbol": "circle"
- },
- "mode": "lines",
- "name": "",
- "orientation": "v",
- "showlegend": false,
- "type": "scatter",
- "x": [
- 0,
- 1,
- 2,
- 3,
- 4,
- 5,
- 6,
- 7,
- 8,
- 9,
- 10,
- 11,
- 12,
- 13,
- 14,
- 15,
- 16,
- 17,
- 18,
- 19
- ],
- "xaxis": "x",
- "y": [
- 1,
- 1,
- 0.9541984732824428,
- 1,
- 1,
- 0.9923664122137404,
- 0.8015267175572519,
- 1,
- 0.5725190839694657,
- 0.6870229007633588,
- 1,
- 0.8015267175572519,
- 1,
- 0.9923664122137404,
- 0.916030534351145,
- 0.7633587786259542,
- 0.9923664122137404,
- 0.8015267175572519,
- 1,
- 1
- ],
- "yaxis": "y"
- }
- ],
- "layout": {
- "legend": {
- "tracegroupgap": 0
- },
- "template": {
- "data": {
- "bar": [
- {
- "error_x": {
- "color": "#2a3f5f"
- },
- "error_y": {
- "color": "#2a3f5f"
- },
- "marker": {
- "line": {
- "color": "#E5ECF6",
- "width": 0.5
- },
- "pattern": {
- "fillmode": "overlay",
- "size": 10,
- "solidity": 0.2
- }
- },
- "type": "bar"
- }
- ],
- "barpolar": [
- {
- "marker": {
- "line": {
- "color": "#E5ECF6",
- "width": 0.5
- },
- "pattern": {
- "fillmode": "overlay",
- "size": 10,
- "solidity": 0.2
- }
- },
- "type": "barpolar"
- }
- ],
- "carpet": [
- {
- "aaxis": {
- "endlinecolor": "#2a3f5f",
- "gridcolor": "white",
- "linecolor": "white",
- "minorgridcolor": "white",
- "startlinecolor": "#2a3f5f"
- },
- "baxis": {
- "endlinecolor": "#2a3f5f",
- "gridcolor": "white",
- "linecolor": "white",
- "minorgridcolor": "white",
- "startlinecolor": "#2a3f5f"
- },
- "type": "carpet"
- }
- ],
- "choropleth": [
- {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- },
- "type": "choropleth"
- }
- ],
- "contour": [
- {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- },
- "colorscale": [
- [
- 0,
- "#0d0887"
- ],
- [
- 0.1111111111111111,
- "#46039f"
- ],
- [
- 0.2222222222222222,
- "#7201a8"
- ],
- [
- 0.3333333333333333,
- "#9c179e"
- ],
- [
- 0.4444444444444444,
- "#bd3786"
- ],
- [
- 0.5555555555555556,
- "#d8576b"
- ],
- [
- 0.6666666666666666,
- "#ed7953"
- ],
- [
- 0.7777777777777778,
- "#fb9f3a"
- ],
- [
- 0.8888888888888888,
- "#fdca26"
- ],
- [
- 1,
- "#f0f921"
- ]
- ],
- "type": "contour"
- }
- ],
- "contourcarpet": [
- {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- },
- "type": "contourcarpet"
- }
- ],
- "heatmap": [
- {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- },
- "colorscale": [
- [
- 0,
- "#0d0887"
- ],
- [
- 0.1111111111111111,
- "#46039f"
- ],
- [
- 0.2222222222222222,
- "#7201a8"
- ],
- [
- 0.3333333333333333,
- "#9c179e"
- ],
- [
- 0.4444444444444444,
- "#bd3786"
- ],
- [
- 0.5555555555555556,
- "#d8576b"
- ],
- [
- 0.6666666666666666,
- "#ed7953"
- ],
- [
- 0.7777777777777778,
- "#fb9f3a"
- ],
- [
- 0.8888888888888888,
- "#fdca26"
- ],
- [
- 1,
- "#f0f921"
- ]
- ],
- "type": "heatmap"
- }
- ],
- "heatmapgl": [
- {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- },
- "colorscale": [
- [
- 0,
- "#0d0887"
- ],
- [
- 0.1111111111111111,
- "#46039f"
- ],
- [
- 0.2222222222222222,
- "#7201a8"
- ],
- [
- 0.3333333333333333,
- "#9c179e"
- ],
- [
- 0.4444444444444444,
- "#bd3786"
- ],
- [
- 0.5555555555555556,
- "#d8576b"
- ],
- [
- 0.6666666666666666,
- "#ed7953"
- ],
- [
- 0.7777777777777778,
- "#fb9f3a"
- ],
- [
- 0.8888888888888888,
- "#fdca26"
- ],
- [
- 1,
- "#f0f921"
- ]
- ],
- "type": "heatmapgl"
- }
- ],
- "histogram": [
- {
- "marker": {
- "pattern": {
- "fillmode": "overlay",
- "size": 10,
- "solidity": 0.2
- }
- },
- "type": "histogram"
- }
- ],
- "histogram2d": [
- {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- },
- "colorscale": [
- [
- 0,
- "#0d0887"
- ],
- [
- 0.1111111111111111,
- "#46039f"
- ],
- [
- 0.2222222222222222,
- "#7201a8"
- ],
- [
- 0.3333333333333333,
- "#9c179e"
- ],
- [
- 0.4444444444444444,
- "#bd3786"
- ],
- [
- 0.5555555555555556,
- "#d8576b"
- ],
- [
- 0.6666666666666666,
- "#ed7953"
- ],
- [
- 0.7777777777777778,
- "#fb9f3a"
- ],
- [
- 0.8888888888888888,
- "#fdca26"
- ],
- [
- 1,
- "#f0f921"
- ]
- ],
- "type": "histogram2d"
- }
- ],
- "histogram2dcontour": [
- {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- },
- "colorscale": [
- [
- 0,
- "#0d0887"
- ],
- [
- 0.1111111111111111,
- "#46039f"
- ],
- [
- 0.2222222222222222,
- "#7201a8"
- ],
- [
- 0.3333333333333333,
- "#9c179e"
- ],
- [
- 0.4444444444444444,
- "#bd3786"
- ],
- [
- 0.5555555555555556,
- "#d8576b"
- ],
- [
- 0.6666666666666666,
- "#ed7953"
- ],
- [
- 0.7777777777777778,
- "#fb9f3a"
- ],
- [
- 0.8888888888888888,
- "#fdca26"
- ],
- [
- 1,
- "#f0f921"
- ]
- ],
- "type": "histogram2dcontour"
- }
- ],
- "mesh3d": [
- {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- },
- "type": "mesh3d"
- }
- ],
- "parcoords": [
- {
- "line": {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- }
- },
- "type": "parcoords"
- }
- ],
- "pie": [
- {
- "automargin": true,
- "type": "pie"
- }
- ],
- "scatter": [
- {
- "fillpattern": {
- "fillmode": "overlay",
- "size": 10,
- "solidity": 0.2
- },
- "type": "scatter"
- }
- ],
- "scatter3d": [
- {
- "line": {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- }
- },
- "marker": {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- }
- },
- "type": "scatter3d"
- }
- ],
- "scattercarpet": [
- {
- "marker": {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- }
- },
- "type": "scattercarpet"
- }
- ],
- "scattergeo": [
- {
- "marker": {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- }
- },
- "type": "scattergeo"
- }
- ],
- "scattergl": [
- {
- "marker": {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- }
- },
- "type": "scattergl"
- }
- ],
- "scattermapbox": [
- {
- "marker": {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- }
- },
- "type": "scattermapbox"
- }
- ],
- "scatterpolar": [
- {
- "marker": {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- }
- },
- "type": "scatterpolar"
- }
- ],
- "scatterpolargl": [
- {
- "marker": {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- }
- },
- "type": "scatterpolargl"
- }
- ],
- "scatterternary": [
- {
- "marker": {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- }
- },
- "type": "scatterternary"
- }
- ],
- "surface": [
- {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- },
- "colorscale": [
- [
- 0,
- "#0d0887"
- ],
- [
- 0.1111111111111111,
- "#46039f"
- ],
- [
- 0.2222222222222222,
- "#7201a8"
- ],
- [
- 0.3333333333333333,
- "#9c179e"
- ],
- [
- 0.4444444444444444,
- "#bd3786"
- ],
- [
- 0.5555555555555556,
- "#d8576b"
- ],
- [
- 0.6666666666666666,
- "#ed7953"
- ],
- [
- 0.7777777777777778,
- "#fb9f3a"
- ],
- [
- 0.8888888888888888,
- "#fdca26"
- ],
- [
- 1,
- "#f0f921"
- ]
- ],
- "type": "surface"
- }
- ],
- "table": [
- {
- "cells": {
- "fill": {
- "color": "#EBF0F8"
- },
- "line": {
- "color": "white"
- }
- },
- "header": {
- "fill": {
- "color": "#C8D4E3"
- },
- "line": {
- "color": "white"
- }
- },
- "type": "table"
- }
- ]
- },
- "layout": {
- "annotationdefaults": {
- "arrowcolor": "#2a3f5f",
- "arrowhead": 0,
- "arrowwidth": 1
- },
- "autotypenumbers": "strict",
- "coloraxis": {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- }
- },
- "colorscale": {
- "diverging": [
- [
- 0,
- "#8e0152"
- ],
- [
- 0.1,
- "#c51b7d"
- ],
- [
- 0.2,
- "#de77ae"
- ],
- [
- 0.3,
- "#f1b6da"
- ],
- [
- 0.4,
- "#fde0ef"
- ],
- [
- 0.5,
- "#f7f7f7"
- ],
- [
- 0.6,
- "#e6f5d0"
- ],
- [
- 0.7,
- "#b8e186"
- ],
- [
- 0.8,
- "#7fbc41"
- ],
- [
- 0.9,
- "#4d9221"
- ],
- [
- 1,
- "#276419"
- ]
- ],
- "sequential": [
- [
- 0,
- "#0d0887"
- ],
- [
- 0.1111111111111111,
- "#46039f"
- ],
- [
- 0.2222222222222222,
- "#7201a8"
- ],
- [
- 0.3333333333333333,
- "#9c179e"
- ],
- [
- 0.4444444444444444,
- "#bd3786"
- ],
- [
- 0.5555555555555556,
- "#d8576b"
- ],
- [
- 0.6666666666666666,
- "#ed7953"
- ],
- [
- 0.7777777777777778,
- "#fb9f3a"
- ],
- [
- 0.8888888888888888,
- "#fdca26"
- ],
- [
- 1,
- "#f0f921"
- ]
- ],
- "sequentialminus": [
- [
- 0,
- "#0d0887"
- ],
- [
- 0.1111111111111111,
- "#46039f"
- ],
- [
- 0.2222222222222222,
- "#7201a8"
- ],
- [
- 0.3333333333333333,
- "#9c179e"
- ],
- [
- 0.4444444444444444,
- "#bd3786"
- ],
- [
- 0.5555555555555556,
- "#d8576b"
- ],
- [
- 0.6666666666666666,
- "#ed7953"
- ],
- [
- 0.7777777777777778,
- "#fb9f3a"
- ],
- [
- 0.8888888888888888,
- "#fdca26"
- ],
- [
- 1,
- "#f0f921"
- ]
- ]
- },
- "colorway": [
- "#636efa",
- "#EF553B",
- "#00cc96",
- "#ab63fa",
- "#FFA15A",
- "#19d3f3",
- "#FF6692",
- "#B6E880",
- "#FF97FF",
- "#FECB52"
- ],
- "font": {
- "color": "#2a3f5f"
- },
- "geo": {
- "bgcolor": "white",
- "lakecolor": "white",
- "landcolor": "#E5ECF6",
- "showlakes": true,
- "showland": true,
- "subunitcolor": "white"
- },
- "hoverlabel": {
- "align": "left"
- },
- "hovermode": "closest",
- "mapbox": {
- "style": "light"
- },
- "paper_bgcolor": "white",
- "plot_bgcolor": "#E5ECF6",
- "polar": {
- "angularaxis": {
- "gridcolor": "white",
- "linecolor": "white",
- "ticks": ""
- },
- "bgcolor": "#E5ECF6",
- "radialaxis": {
- "gridcolor": "white",
- "linecolor": "white",
- "ticks": ""
- }
- },
- "scene": {
- "xaxis": {
- "backgroundcolor": "#E5ECF6",
- "gridcolor": "white",
- "gridwidth": 2,
- "linecolor": "white",
- "showbackground": true,
- "ticks": "",
- "zerolinecolor": "white"
- },
- "yaxis": {
- "backgroundcolor": "#E5ECF6",
- "gridcolor": "white",
- "gridwidth": 2,
- "linecolor": "white",
- "showbackground": true,
- "ticks": "",
- "zerolinecolor": "white"
- },
- "zaxis": {
- "backgroundcolor": "#E5ECF6",
- "gridcolor": "white",
- "gridwidth": 2,
- "linecolor": "white",
- "showbackground": true,
- "ticks": "",
- "zerolinecolor": "white"
- }
- },
- "shapedefaults": {
- "line": {
- "color": "#2a3f5f"
- }
- },
- "ternary": {
- "aaxis": {
- "gridcolor": "white",
- "linecolor": "white",
- "ticks": ""
- },
- "baxis": {
- "gridcolor": "white",
- "linecolor": "white",
- "ticks": ""
- },
- "bgcolor": "#E5ECF6",
- "caxis": {
- "gridcolor": "white",
- "linecolor": "white",
- "ticks": ""
- }
- },
- "title": {
- "x": 0.05
- },
- "xaxis": {
- "automargin": true,
- "gridcolor": "white",
- "linecolor": "white",
- "ticks": "",
- "title": {
- "standoff": 15
- },
- "zerolinecolor": "white",
- "zerolinewidth": 2
- },
- "yaxis": {
- "automargin": true,
- "gridcolor": "white",
- "linecolor": "white",
- "ticks": "",
- "title": {
- "standoff": 15
- },
- "zerolinecolor": "white",
- "zerolinewidth": 2
- }
- }
- },
- "title": {
- "text": "Model Health"
- },
- "xaxis": {
- "anchor": "y",
- "domain": [
- 0,
- 1
- ],
- "title": {
- "text": "bin"
- }
- },
- "yaxis": {
- "anchor": "x",
- "domain": [
- 0,
- 1
- ],
- "title": {
- "text": "model_health"
- }
- }
- }
- }
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
+ "outputs": [],
"source": [
- "# use plotly to visualize the model health over the bins\n",
"model_health_df = pd.DataFrame(model_health, columns=[\"model_health\"])\n",
"\n",
"model_health_df[\"bin\"] = np.arange(0, len(model_health_df))\n",
diff --git a/readmission_prediction_detectron.ipynb b/readmission_prediction_detectron.ipynb
deleted file mode 100644
index 7be2dd414..000000000
--- a/readmission_prediction_detectron.ipynb
+++ /dev/null
@@ -1,2343 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Readmission Prediction\n",
- "\n",
- "This notebook showcases readmission prediction on the [Diabetes 130-US Hospitals for Years 1999-2008](https://archive.ics.uci.edu/dataset/296/diabetes+130-us+hospitals+for+years+1999-2008) using CyclOps. The task is formulated as a binary classification task, where we predict the probability of early readmission of the patient within 30 days of discharge."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "tags": []
- },
- "source": [
- "## Install libraries"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Requirement already satisfied: pycyclops[xgboost] in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (0.2.9)\n",
- "Requirement already satisfied: Jinja2<4.0.0,>=3.1.3 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (3.1.4)\n",
- "Requirement already satisfied: array-api-compat==1.6 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (1.6)\n",
- "Requirement already satisfied: datasets<3.0.0,>=2.15.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (2.19.0)\n",
- "Requirement already satisfied: hydra-core<2.0.0,>=1.2.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (1.3.2)\n",
- "Requirement already satisfied: kaleido==0.2.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (0.2.1)\n",
- "Requirement already satisfied: matplotlib<4.0.0,>=3.8.3 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (3.8.3)\n",
- "Requirement already satisfied: numpy<2.0.0,>=1.24.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (1.24.4)\n",
- "Requirement already satisfied: pandas<3.0,>=2.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pandas[performance]<3.0,>=2.1->pycyclops[xgboost]) (2.1.4)\n",
- "Requirement already satisfied: pillow<11.0.0,>=10.0.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (10.3.0)\n",
- "Requirement already satisfied: plotly<6.0.0,>=5.7.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (5.18.0)\n",
- "Requirement already satisfied: psutil<6.0.0,>=5.9.4 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (5.9.7)\n",
- "Requirement already satisfied: pyarrow<15.0.0,>=14.0.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (14.0.2)\n",
- "Requirement already satisfied: pybtex<0.25.0,>=0.24.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (0.24.0)\n",
- "Requirement already satisfied: pydantic<2.0.0,>=1.10.11 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (1.10.13)\n",
- "Requirement already satisfied: scikit-learn<2.0.0,>=1.4.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (1.5.0)\n",
- "Requirement already satisfied: scipy<2.0.0,>=1.11.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (1.13.0rc1)\n",
- "Requirement already satisfied: scour<0.39.0,>=0.38.2 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (0.38.2)\n",
- "Requirement already satisfied: spdx-tools<0.9.0,>=0.8.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (0.8.2)\n",
- "Requirement already satisfied: xgboost<2.0.0,>=1.5.2 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pycyclops[xgboost]) (1.7.6)\n",
- "Requirement already satisfied: filelock in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (3.13.1)\n",
- "Requirement already satisfied: pyarrow-hotfix in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (0.6)\n",
- "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (0.3.7)\n",
- "Requirement already satisfied: requests>=2.19.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (2.32.0)\n",
- "Requirement already satisfied: tqdm>=4.62.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (4.66.4)\n",
- "Requirement already satisfied: xxhash in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (3.4.1)\n",
- "Requirement already satisfied: multiprocess in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (0.70.15)\n",
- "Requirement already satisfied: fsspec<=2024.3.1,>=2023.1.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from fsspec[http]<=2024.3.1,>=2023.1.0->datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (2023.10.0)\n",
- "Requirement already satisfied: aiohttp in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (3.9.5)\n",
- "Requirement already satisfied: huggingface-hub>=0.21.2 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (0.22.2)\n",
- "Requirement already satisfied: packaging in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (23.2)\n",
- "Requirement already satisfied: pyyaml>=5.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (6.0.1)\n",
- "Requirement already satisfied: omegaconf<2.4,>=2.2 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from hydra-core<2.0.0,>=1.2.0->pycyclops[xgboost]) (2.3.0)\n",
- "Requirement already satisfied: antlr4-python3-runtime==4.9.* in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from hydra-core<2.0.0,>=1.2.0->pycyclops[xgboost]) (4.9.3)\n",
- "Requirement already satisfied: MarkupSafe>=2.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from Jinja2<4.0.0,>=3.1.3->pycyclops[xgboost]) (2.1.3)\n",
- "Requirement already satisfied: contourpy>=1.0.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from matplotlib<4.0.0,>=3.8.3->pycyclops[xgboost]) (1.1.0)\n",
- "Requirement already satisfied: cycler>=0.10 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from matplotlib<4.0.0,>=3.8.3->pycyclops[xgboost]) (0.12.1)\n",
- "Requirement already satisfied: fonttools>=4.22.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from matplotlib<4.0.0,>=3.8.3->pycyclops[xgboost]) (4.47.0)\n",
- "Requirement already satisfied: kiwisolver>=1.3.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from matplotlib<4.0.0,>=3.8.3->pycyclops[xgboost]) (1.4.5)\n",
- "Requirement already satisfied: pyparsing>=2.3.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from matplotlib<4.0.0,>=3.8.3->pycyclops[xgboost]) (3.1.1)\n",
- "Requirement already satisfied: python-dateutil>=2.7 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from matplotlib<4.0.0,>=3.8.3->pycyclops[xgboost]) (2.8.2)\n",
- "Requirement already satisfied: importlib-resources>=3.2.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from matplotlib<4.0.0,>=3.8.3->pycyclops[xgboost]) (6.1.1)\n",
- "Requirement already satisfied: pytz>=2020.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pandas<3.0,>=2.1->pandas[performance]<3.0,>=2.1->pycyclops[xgboost]) (2023.3.post1)\n",
- "Requirement already satisfied: tzdata>=2022.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pandas<3.0,>=2.1->pandas[performance]<3.0,>=2.1->pycyclops[xgboost]) (2023.3)\n",
- "Requirement already satisfied: bottleneck>=1.3.4 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pandas[performance]<3.0,>=2.1->pycyclops[xgboost]) (1.3.8)\n",
- "Requirement already satisfied: numba>=0.55.2 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pandas[performance]<3.0,>=2.1->pycyclops[xgboost]) (0.57.1)\n",
- "Requirement already satisfied: numexpr>=2.8.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pandas[performance]<3.0,>=2.1->pycyclops[xgboost]) (2.10.0)\n",
- "Requirement already satisfied: tenacity>=6.2.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from plotly<6.0.0,>=5.7.0->pycyclops[xgboost]) (8.2.3)\n",
- "Requirement already satisfied: latexcodec>=1.0.4 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pybtex<0.25.0,>=0.24.0->pycyclops[xgboost]) (2.0.1)\n",
- "Requirement already satisfied: six in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pybtex<0.25.0,>=0.24.0->pycyclops[xgboost]) (1.16.0)\n",
- "Requirement already satisfied: typing-extensions>=4.2.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pydantic<2.0.0,>=1.10.11->pycyclops[xgboost]) (4.9.0)\n",
- "Requirement already satisfied: joblib>=1.2.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from scikit-learn<2.0.0,>=1.4.0->pycyclops[xgboost]) (1.3.2)\n",
- "Requirement already satisfied: threadpoolctl>=3.1.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from scikit-learn<2.0.0,>=1.4.0->pycyclops[xgboost]) (3.2.0)\n",
- "Requirement already satisfied: click in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from spdx-tools<0.9.0,>=0.8.1->pycyclops[xgboost]) (8.1.7)\n",
- "Requirement already satisfied: xmltodict in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from spdx-tools<0.9.0,>=0.8.1->pycyclops[xgboost]) (0.13.0)\n",
- "Requirement already satisfied: rdflib in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from spdx-tools<0.9.0,>=0.8.1->pycyclops[xgboost]) (7.0.0)\n",
- "Requirement already satisfied: beartype in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from spdx-tools<0.9.0,>=0.8.1->pycyclops[xgboost]) (0.16.4)\n",
- "Requirement already satisfied: uritools in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from spdx-tools<0.9.0,>=0.8.1->pycyclops[xgboost]) (4.0.2)\n",
- "Requirement already satisfied: license-expression in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from spdx-tools<0.9.0,>=0.8.1->pycyclops[xgboost]) (30.2.0)\n",
- "Requirement already satisfied: ply in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from spdx-tools<0.9.0,>=0.8.1->pycyclops[xgboost]) (3.11)\n",
- "Requirement already satisfied: semantic-version in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from spdx-tools<0.9.0,>=0.8.1->pycyclops[xgboost]) (2.10.0)\n",
- "Requirement already satisfied: aiosignal>=1.1.2 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from aiohttp->datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (1.3.1)\n",
- "Requirement already satisfied: attrs>=17.3.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from aiohttp->datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (23.1.0)\n",
- "Requirement already satisfied: frozenlist>=1.1.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from aiohttp->datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (1.4.1)\n",
- "Requirement already satisfied: multidict<7.0,>=4.5 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from aiohttp->datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (6.0.4)\n",
- "Requirement already satisfied: yarl<2.0,>=1.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from aiohttp->datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (1.9.4)\n",
- "Requirement already satisfied: async-timeout<5.0,>=4.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from aiohttp->datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (4.0.3)\n",
- "Requirement already satisfied: zipp>=3.1.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from importlib-resources>=3.2.0->matplotlib<4.0.0,>=3.8.3->pycyclops[xgboost]) (3.17.0)\n",
- "Requirement already satisfied: llvmlite<0.41,>=0.40.0dev0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from numba>=0.55.2->pandas[performance]<3.0,>=2.1->pycyclops[xgboost]) (0.40.1)\n",
- "Requirement already satisfied: charset-normalizer<4,>=2 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from requests>=2.19.0->datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (3.3.2)\n",
- "Requirement already satisfied: idna<4,>=2.5 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from requests>=2.19.0->datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (3.7)\n",
- "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from requests>=2.19.0->datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (2.2.2)\n",
- "Requirement already satisfied: certifi>=2017.4.17 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from requests>=2.19.0->datasets<3.0.0,>=2.15.0->pycyclops[xgboost]) (2023.11.17)\n",
- "Requirement already satisfied: boolean.py>=4.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from license-expression->spdx-tools<0.9.0,>=0.8.1->pycyclops[xgboost]) (4.0)\n",
- "Requirement already satisfied: isodate<0.7.0,>=0.6.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from rdflib->spdx-tools<0.9.0,>=0.8.1->pycyclops[xgboost]) (0.6.1)\n",
- "\n",
- "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.1.2\u001b[0m\n",
- "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
- "Requirement already satisfied: ucimlrepo in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (0.0.7)\n",
- "Requirement already satisfied: pandas>=1.0.0 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from ucimlrepo) (2.1.4)\n",
- "Requirement already satisfied: certifi>=2020.12.5 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from ucimlrepo) (2023.11.17)\n",
- "Requirement already satisfied: numpy<2,>=1.22.4 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pandas>=1.0.0->ucimlrepo) (1.24.4)\n",
- "Requirement already satisfied: python-dateutil>=2.8.2 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pandas>=1.0.0->ucimlrepo) (2.8.2)\n",
- "Requirement already satisfied: pytz>=2020.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pandas>=1.0.0->ucimlrepo) (2023.3.post1)\n",
- "Requirement already satisfied: tzdata>=2022.1 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from pandas>=1.0.0->ucimlrepo) (2023.3)\n",
- "Requirement already satisfied: six>=1.5 in /home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages (from python-dateutil>=2.8.2->pandas>=1.0.0->ucimlrepo) (1.16.0)\n",
- "\n",
- "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.1.2\u001b[0m\n",
- "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
- ]
- }
- ],
- "source": [
- "!pip install pycyclops[xgboost]\n",
- "!pip install ucimlrepo"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Import Libraries"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
- " from .autonotebook import tqdm as notebook_tqdm\n"
- ]
- }
- ],
- "source": [
- "\"\"\"Readmission prediction.\"\"\"\n",
- "\n",
- "# ruff: noqa: E402\n",
- "\n",
- "import copy\n",
- "import inspect\n",
- "from datetime import date\n",
- "\n",
- "import numpy as np\n",
- "import pandas as pd\n",
- "import plotly.express as px\n",
- "from datasets import Dataset\n",
- "from datasets.features import ClassLabel\n",
- "from sklearn.compose import ColumnTransformer\n",
- "from sklearn.impute import SimpleImputer\n",
- "from sklearn.pipeline import Pipeline\n",
- "from sklearn.preprocessing import MinMaxScaler, OneHotEncoder\n",
- "from ucimlrepo import fetch_ucirepo\n",
- "\n",
- "from cyclops.data.df.feature import TabularFeatures\n",
- "from cyclops.data.slicer import SliceSpec\n",
- "from cyclops.evaluate.fairness import FairnessConfig # noqa: E402\n",
- "from cyclops.evaluate.metrics import create_metric\n",
- "from cyclops.evaluate.metrics.experimental.functional import (\n",
- " binary_npv,\n",
- " binary_ppv,\n",
- " binary_roc,\n",
- ")\n",
- "from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict\n",
- "from cyclops.models.catalog import create_model\n",
- "from cyclops.report import ModelCardReport\n",
- "from cyclops.report.plot.classification import ClassificationPlotter\n",
- "from cyclops.report.utils import flatten_results_dict\n",
- "from cyclops.tasks import BinaryTabularClassificationTask\n",
- "\n",
- "from cyclops.monitor.tester import Detectron\n",
- "from datasets import DatasetDict"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Constants"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "RANDOM_SEED = 85\n",
- "NAN_THRESHOLD = 0.75\n",
- "TRAIN_SIZE = 0.05\n",
- "EVAL_NUM = 3"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Data Loading"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/akore/.cache/pypoetry/virtualenvs/pycyclops-4J2PL5I8-py3.9/lib/python3.9/site-packages/ucimlrepo/fetch.py:97: DtypeWarning:\n",
- "\n",
- "Columns (10) have mixed types. Specify dtype option on import or set low_memory=False.\n",
- "\n"
- ]
- }
- ],
- "source": [
- "diabetes_130_data = fetch_ucirepo(id=296)\n",
- "features = diabetes_130_data[\"data\"][\"features\"]\n",
- "targets = diabetes_130_data[\"data\"][\"targets\"]\n",
- "metadata = diabetes_130_data[\"metadata\"]\n",
- "variables = diabetes_130_data[\"variables\"]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "{'uci_id': 296,\n",
- " 'name': 'Diabetes 130-US Hospitals for Years 1999-2008',\n",
- " 'repository_url': 'https://archive.ics.uci.edu/dataset/296/diabetes+130-us+hospitals+for+years+1999-2008',\n",
- " 'data_url': 'https://archive.ics.uci.edu/static/public/296/data.csv',\n",
- " 'abstract': 'The dataset represents ten years (1999-2008) of clinical care at 130 US hospitals and integrated delivery networks. Each row concerns hospital records of patients diagnosed with diabetes, who underwent laboratory, medications, and stayed up to 14 days. The goal is to determine the early readmission of the patient within 30 days of discharge.\\nThe problem is important for the following reasons. Despite high-quality evidence showing improved clinical outcomes for diabetic patients who receive various preventive and therapeutic interventions, many patients do not receive them. This can be partially attributed to arbitrary diabetes management in hospital environments, which fail to attend to glycemic control. Failure to provide proper diabetes care not only increases the managing costs for the hospitals (as the patients are readmitted) but also impacts the morbidity and mortality of the patients, who may face complications associated with diabetes.\\n',\n",
- " 'area': 'Health and Medicine',\n",
- " 'tasks': ['Classification', 'Clustering'],\n",
- " 'characteristics': ['Multivariate'],\n",
- " 'num_instances': 101766,\n",
- " 'num_features': 47,\n",
- " 'feature_types': ['Categorical', 'Integer'],\n",
- " 'demographics': ['Race', 'Gender', 'Age'],\n",
- " 'target_col': ['readmitted'],\n",
- " 'index_col': ['encounter_id', 'patient_nbr'],\n",
- " 'has_missing_values': 'yes',\n",
- " 'missing_values_symbol': 'NaN',\n",
- " 'year_of_dataset_creation': 2014,\n",
- " 'last_updated': 'Mon Feb 26 2024',\n",
- " 'dataset_doi': '10.24432/C5230J',\n",
- " 'creators': ['John Clore', 'Krzysztof Cios', 'Jon DeShazo', 'Beata Strack'],\n",
- " 'intro_paper': {'title': 'Impact of HbA1c Measurement on Hospital Readmission Rates: Analysis of 70,000 Clinical Database Patient Record',\n",
- " 'authors': 'Beata Strack, Jonathan DeShazo, Chris Gennings, Juan Olmo, Sebastian Ventura, Krzysztof Cios, John Clore',\n",
- " 'published_in': 'BioMed Research International, vol. 2014',\n",
- " 'year': 2014,\n",
- " 'url': 'https://www.hindawi.com/journals/bmri/2014/781670/',\n",
- " 'doi': None},\n",
- " 'additional_info': {'summary': 'The dataset represents ten years (1999-2008) of clinical care at 130 US hospitals and integrated delivery networks. It includes over 50 features representing patient and hospital outcomes. Information was extracted from the database for encounters that satisfied the following criteria.\\n(1)\\tIt is an inpatient encounter (a hospital admission).\\n(2)\\tIt is a diabetic encounter, that is, one during which any kind of diabetes was entered into the system as a diagnosis.\\n(3)\\tThe length of stay was at least 1 day and at most 14 days.\\n(4)\\tLaboratory tests were performed during the encounter.\\n(5)\\tMedications were administered during the encounter.\\n\\nThe data contains such attributes as patient number, race, gender, age, admission type, time in hospital, medical specialty of admitting physician, number of lab tests performed, HbA1c test result, diagnosis, number of medications, diabetic medications, number of outpatient, inpatient, and emergency visits in the year before the hospitalization, etc.',\n",
- " 'purpose': None,\n",
- " 'funded_by': None,\n",
- " 'instances_represent': 'The instances represent hospitalized patient records diagnosed with diabetes.',\n",
- " 'recommended_data_splits': 'No recommendation. The standard train-test split could be used. Can use three-way holdout split (i.e., train-validation-test) when doing model selection.',\n",
- " 'sensitive_data': 'Yes. The dataset contains information about the age, gender, and race of the patients.',\n",
- " 'preprocessing_description': None,\n",
- " 'variable_info': 'Detailed description of all the atrributes is provided in Table 1 Beata Strack, Jonathan P. DeShazo, Chris Gennings, Juan L. Olmo, Sebastian Ventura, Krzysztof J. Cios, and John N. Clore, “Impact of HbA1c Measurement on Hospital Readmission Rates: Analysis of 70,000 Clinical Database Patient Records,” BioMed Research International, vol. 2014, Article ID 781670, 11 pages, 2014.\\n\\nhttp://www.hindawi.com/journals/bmri/2014/781670/',\n",
- " 'citation': 'Please cite:\\nBeata Strack, Jonathan P. DeShazo, Chris Gennings, Juan L. Olmo, Sebastian Ventura, Krzysztof J. Cios, and John N. Clore, “Impact of HbA1c Measurement on Hospital Readmission Rates: Analysis of 70,000 Clinical Database Patient Records,” BioMed Research International, vol. 2014, Article ID 781670, 11 pages, 2014.'}}"
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "metadata"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "def transform_label(value):\n",
- " \"\"\"Transform string labels of readmission into 0/1 binary labels.\n",
- "\n",
- " Parameters\n",
- " ----------\n",
- " value: str\n",
- " Input value\n",
- "\n",
- " Returns\n",
- " -------\n",
- " int\n",
- " 0 if not readmitted or if greater than 30 days, 1 if less than 30 days\n",
- "\n",
- " \"\"\"\n",
- " if value in [\"NO\", \">30\"]:\n",
- " return 0\n",
- " if value == \"<30\":\n",
- " return 1\n",
- "\n",
- " raise ValueError(\"Unexpected value for readmission!\")\n",
- "\n",
- "\n",
- "df = features\n",
- "targets[\"readmitted\"] = targets[\"readmitted\"].apply(transform_label)\n",
- "df[\"readmitted\"] = targets"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Choose a small subset for modelling"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [],
- "source": [
- "df = df[0:1000000]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Remove features that are NaNs or have just a single unique value"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [],
- "source": [
- "df[\"outcome\"] = df[\"readmitted\"].astype(\"int\")\n",
- "df = df.drop(columns=[\"readmitted\"])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "features_to_remove = []\n",
- "for col in df:\n",
- " if len(df[col].value_counts()) <= 1:\n",
- " features_to_remove.append(col)\n",
- "df = df.drop(columns=features_to_remove)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "7.960641014352381 outcome\n",
- "0 90409\n",
- "1 11357\n",
- "Name: count, dtype: int64\n"
- ]
- }
- ],
- "source": [
- "class_counts = df[\"outcome\"].value_counts()\n",
- "class_ratio = class_counts[0] / class_counts[1]\n",
- "print(class_ratio, class_counts)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "From the features in the dataset, we select all of them to train the model!"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "features_list = list(df.columns)\n",
- "features_list.remove(\"outcome\")\n",
- "features_list = sorted(features_list)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Identifying feature types\n",
- "\n",
- "Cyclops `TabularFeatures` class helps to identify feature types, an essential step before preprocessing the data. Understanding feature types (numerical/categorical/binary) allows us to apply appropriate preprocessing steps for each type."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{'A1Cresult': 'ordinal', 'age': 'ordinal', 'pioglitazone': 'ordinal', 'num_medications': 'numeric', 'metformin-rosiglitazone': 'binary', 'tolazamide': 'ordinal', 'glipizide': 'ordinal', 'number_inpatient': 'numeric', 'troglitazone': 'binary', 'acarbose': 'ordinal', 'glyburide-metformin': 'ordinal', 'acetohexamide': 'binary', 'chlorpropamide': 'ordinal', 'medical_specialty': 'string', 'max_glu_serum': 'ordinal', 'repaglinide': 'ordinal', 'rosiglitazone': 'ordinal', 'admission_type_id': 'ordinal', 'glimepiride': 'ordinal', 'gender': 'ordinal', 'glipizide-metformin': 'binary', 'num_lab_procedures': 'numeric', 'number_emergency': 'numeric', 'glimepiride-pioglitazone': 'binary', 'nateglinide': 'ordinal', 'discharge_disposition_id': 'numeric', 'payer_code': 'ordinal', 'num_procedures': 'ordinal', 'number_outpatient': 'numeric', 'diag_3': 'string', 'change': 'binary', 'diabetesMed': 'binary', 'miglitol': 'ordinal', 'race': 'ordinal', 'diag_1': 'string', 'outcome': 'binary', 'diag_2': 'string', 'glyburide': 'ordinal', 'metformin': 'ordinal', 'metformin-pioglitazone': 'binary', 'weight': 'ordinal', 'admission_source_id': 'ordinal', 'tolbutamide': 'binary', 'number_diagnoses': 'ordinal', 'insulin': 'ordinal', 'time_in_hospital': 'ordinal'}\n"
- ]
- }
- ],
- "source": [
- "tab_features = TabularFeatures(\n",
- " data=df.reset_index(),\n",
- " features=features_list,\n",
- " by=\"index\",\n",
- " targets=\"outcome\",\n",
- ")\n",
- "print(tab_features.types)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Creating data preprocessors\n",
- "\n",
- "We create a data preprocessor using sklearn's ColumnTransformer. This helps in applying different preprocessing steps to different columns in the dataframe. For instance, binary features might be processed differently from numeric features."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "numeric_transformer = Pipeline(\n",
- " steps=[(\"imputer\", SimpleImputer(strategy=\"mean\")), (\"scaler\", MinMaxScaler())],\n",
- ")\n",
- "\n",
- "binary_transformer = Pipeline(\n",
- " steps=[(\"imputer\", SimpleImputer(strategy=\"most_frequent\"))],\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "['discharge_disposition_id', 'num_lab_procedures', 'num_medications', 'number_emergency', 'number_inpatient', 'number_outpatient']\n"
- ]
- }
- ],
- "source": [
- "numeric_features = sorted((tab_features.features_by_type(\"numeric\")))\n",
- "numeric_indices = [\n",
- " df[features_list].columns.get_loc(column) for column in numeric_features\n",
- "]\n",
- "print(numeric_features)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "['acetohexamide', 'change', 'diabetesMed', 'glimepiride-pioglitazone', 'glipizide-metformin', 'metformin-pioglitazone', 'metformin-rosiglitazone', 'tolbutamide', 'troglitazone'] ['A1Cresult', 'acarbose', 'admission_source_id', 'admission_type_id', 'age', 'chlorpropamide', 'diag_1', 'diag_2', 'diag_3', 'gender', 'glimepiride', 'glipizide', 'glyburide', 'glyburide-metformin', 'insulin', 'max_glu_serum', 'medical_specialty', 'metformin', 'miglitol', 'nateglinide', 'num_procedures', 'number_diagnoses', 'payer_code', 'pioglitazone', 'race', 'repaglinide', 'rosiglitazone', 'time_in_hospital', 'tolazamide', 'weight']\n"
- ]
- }
- ],
- "source": [
- "binary_features = sorted(tab_features.features_by_type(\"binary\"))\n",
- "binary_features.remove(\"outcome\")\n",
- "ordinal_features = sorted(\n",
- " tab_features.features_by_type(\"ordinal\")\n",
- " + [\"medical_specialty\", \"diag_1\", \"diag_2\", \"diag_3\"]\n",
- ")\n",
- "binary_indices = [\n",
- " df[features_list].columns.get_loc(column) for column in binary_features\n",
- "]\n",
- "ordinal_indices = [\n",
- " df[features_list].columns.get_loc(column) for column in ordinal_features\n",
- "]\n",
- "print(binary_features, ordinal_features)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "preprocessor = ColumnTransformer(\n",
- " transformers=[\n",
- " (\"num\", numeric_transformer, numeric_indices),\n",
- " (\n",
- " \"onehot\",\n",
- " OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False),\n",
- " binary_indices + ordinal_indices,\n",
- " ),\n",
- " ],\n",
- " remainder=\"passthrough\",\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Creating Hugging Face Dataset\n",
- "\n",
- "We convert our processed Pandas dataframe into a Hugging Face dataset, a powerful and easy-to-use data format which is also compatible with CyclOps models and evaluator modules. The dataset is then split to train and test sets."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Dataset({\n",
- " features: ['race', 'gender', 'age', 'weight', 'admission_type_id', 'discharge_disposition_id', 'admission_source_id', 'time_in_hospital', 'payer_code', 'medical_specialty', 'num_lab_procedures', 'num_procedures', 'num_medications', 'number_outpatient', 'number_emergency', 'number_inpatient', 'diag_1', 'diag_2', 'diag_3', 'number_diagnoses', 'max_glu_serum', 'A1Cresult', 'metformin', 'repaglinide', 'nateglinide', 'chlorpropamide', 'glimepiride', 'acetohexamide', 'glipizide', 'glyburide', 'tolbutamide', 'pioglitazone', 'rosiglitazone', 'acarbose', 'miglitol', 'troglitazone', 'tolazamide', 'insulin', 'glyburide-metformin', 'glipizide-metformin', 'glimepiride-pioglitazone', 'metformin-rosiglitazone', 'metformin-pioglitazone', 'change', 'diabetesMed', 'outcome'],\n",
- " num_rows: 101766\n",
- "})\n"
- ]
- }
- ],
- "source": [
- "# upsample the minority class\n",
- "from sklearn.utils import resample\n",
- "\n",
- "# df_majority = df[df.outcome == 0]\n",
- "\n",
- "# df_minority = df[df.outcome == 1]\n",
- "\n",
- "# df_minority_upsampled = resample(\n",
- "# df_minority,\n",
- "# replace=True,\n",
- "# n_samples=len(df_majority),\n",
- "# random_state=RANDOM_SEED,\n",
- "# )\n",
- "\n",
- "# df_upsampled = pd.concat([df_majority, df_minority_upsampled])\n",
- "# df_upsampled = df_upsampled.sample(frac=1, random_state=RANDOM_SEED)\n",
- "# print(df_upsampled.outcome.value_counts())\n",
- "# dataset = Dataset.from_pandas(df_upsampled)\n",
- "\n",
- "dataset = Dataset.from_pandas(df)\n",
- "dataset.cleanup_cache_files()\n",
- "print(dataset)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "dataset = dataset.cast_column(\"outcome\", ClassLabel(num_classes=2))\n",
- "dataset = dataset.train_test_split(\n",
- " train_size=TRAIN_SIZE,\n",
- " stratify_by_column=\"outcome\",\n",
- " seed=RANDOM_SEED,\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Model Creation\n",
- "\n",
- "CyclOps model registry allows for straightforward creation and selection of models. This registry maintains a list of pre-configured models, which can be instantiated with a single line of code. Here we use a SGD classifier to fit a logisitic regression model. The model configurations can be passed to `create_model` based on the sklearn parameters for SGDClassifier."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 19,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "model_name = \"xgb_classifier\"\n",
- "model = create_model(model_name, random_state=123)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Task Creation\n",
- "\n",
- "We use Cyclops tasks to define our model's task (in this case, readmission prediction), train the model, make predictions, and evaluate performance. Cyclops task classes encapsulate the entire ML pipeline into a single, cohesive structure, making the process smooth and easy to manage."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 20,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "readmission_prediction_task = BinaryTabularClassificationTask(\n",
- " {model_name: model},\n",
- " task_features=features_list,\n",
- " task_target=\"outcome\",\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 21,
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "['xgb_classifier']"
- ]
- },
- "execution_count": 21,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "readmission_prediction_task.list_models()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Training\n",
- "\n",
- "If `best_model_params` is passed to the `train` method, the best model will be selected after the hyperparameter search. The parameters in `best_model_params` indicate the values to create the parameters grid.\n",
- "\n",
- "Note that the data preprocessor needs to be passed to the tasks methods if the Hugging Face dataset is not already preprocessed. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 22,
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "2024-07-16 09:18:47,811 \u001b[1;37mINFO\u001b[0m cyclops.models.wrappers.sk_model - No validation split was found.\n",
- "2024-07-16 09:20:34,989 \u001b[1;37mINFO\u001b[0m cyclops.models.wrappers.sk_model - Best scale_pos_weight: 7\n",
- "2024-07-16 09:20:34,990 \u001b[1;37mINFO\u001b[0m cyclops.models.wrappers.sk_model - Best reg_lambda: 0\n",
- "2024-07-16 09:20:34,990 \u001b[1;37mINFO\u001b[0m cyclops.models.wrappers.sk_model - Best n_estimators: 100\n",
- "2024-07-16 09:20:34,991 \u001b[1;37mINFO\u001b[0m cyclops.models.wrappers.sk_model - Best max_depth: 5\n",
- "2024-07-16 09:20:34,991 \u001b[1;37mINFO\u001b[0m cyclops.models.wrappers.sk_model - Best learning_rate: 0.1\n",
- "2024-07-16 09:20:34,991 \u001b[1;37mINFO\u001b[0m cyclops.models.wrappers.sk_model - Best gamma: 0\n",
- "2024-07-16 09:20:34,992 \u001b[1;37mINFO\u001b[0m cyclops.models.wrappers.sk_model - Best colsample_bytree: 1\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "XGBClassifier(base_score=None, booster=None, callbacks=None,\n",
- " colsample_bylevel=None, colsample_bynode=None, colsample_bytree=1,\n",
- " early_stopping_rounds=None, enable_categorical=False,\n",
- " eval_metric='logloss', feature_types=None, gamma=0, gpu_id=None,\n",
- " grow_policy=None, importance_type=None,\n",
- " interaction_constraints=None, learning_rate=0.1, max_bin=None,\n",
- " max_cat_threshold=None, max_cat_to_onehot=None,\n",
- " max_delta_step=None, max_depth=5, max_leaves=None,\n",
- " min_child_weight=3, missing=nan, monotone_constraints=None,\n",
- " n_estimators=100, n_jobs=None, num_parallel_tree=None,\n",
- " predictor=None, random_state=123, ...) In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. XGBClassifieriFitted XGBClassifier(base_score=None, booster=None, callbacks=None,\n",
- " colsample_bylevel=None, colsample_bynode=None, colsample_bytree=1,\n",
- " early_stopping_rounds=None, enable_categorical=False,\n",
- " eval_metric='logloss', feature_types=None, gamma=0, gpu_id=None,\n",
- " grow_policy=None, importance_type=None,\n",
- " interaction_constraints=None, learning_rate=0.1, max_bin=None,\n",
- " max_cat_threshold=None, max_cat_to_onehot=None,\n",
- " max_delta_step=None, max_depth=5, max_leaves=None,\n",
- " min_child_weight=3, missing=nan, monotone_constraints=None,\n",
- " n_estimators=100, n_jobs=None, num_parallel_tree=None,\n",
- " predictor=None, random_state=123, ...) "
- ],
- "text/plain": [
- "XGBClassifier(base_score=None, booster=None, callbacks=None,\n",
- " colsample_bylevel=None, colsample_bynode=None, colsample_bytree=1,\n",
- " early_stopping_rounds=None, enable_categorical=False,\n",
- " eval_metric='logloss', feature_types=None, gamma=0, gpu_id=None,\n",
- " grow_policy=None, importance_type=None,\n",
- " interaction_constraints=None, learning_rate=0.1, max_bin=None,\n",
- " max_cat_threshold=None, max_cat_to_onehot=None,\n",
- " max_delta_step=None, max_depth=5, max_leaves=None,\n",
- " min_child_weight=3, missing=nan, monotone_constraints=None,\n",
- " n_estimators=100, n_jobs=None, num_parallel_tree=None,\n",
- " predictor=None, random_state=123, ...)"
- ]
- },
- "execution_count": 22,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "best_model_params = {\n",
- " \"n_estimators\": [100, 250, 500],\n",
- " \"learning_rate\": [0.1, 0.01],\n",
- " \"max_depth\": [2, 5],\n",
- " \"reg_lambda\": [0, 1, 10],\n",
- " \"colsample_bytree\": [0.7, 0.8, 1],\n",
- " \"gamma\": [0, 1, 2, 10],\n",
- " \"method\": \"random\",\n",
- " \"scale_pos_weight\": [int(class_ratio)],\n",
- "}\n",
- "dataset[\"train\"] = dataset[\"train\"].train_test_split(train_size=0.8, seed=RANDOM_SEED)\n",
- "\n",
- "readmission_prediction_task.train(\n",
- " dataset[\"train\"],\n",
- " model_name=model_name,\n",
- " transforms=preprocessor,\n",
- " best_model_params=best_model_params,\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 23,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{'objective': 'binary:logistic', 'use_label_encoder': None, 'base_score': None, 'booster': None, 'callbacks': None, 'colsample_bylevel': None, 'colsample_bynode': None, 'colsample_bytree': 1, 'early_stopping_rounds': None, 'enable_categorical': False, 'eval_metric': 'logloss', 'feature_types': None, 'gamma': 0, 'gpu_id': None, 'grow_policy': None, 'importance_type': None, 'interaction_constraints': None, 'learning_rate': 0.1, 'max_bin': None, 'max_cat_threshold': None, 'max_cat_to_onehot': None, 'max_delta_step': None, 'max_depth': 5, 'max_leaves': None, 'min_child_weight': 3, 'missing': nan, 'monotone_constraints': None, 'n_estimators': 100, 'n_jobs': None, 'num_parallel_tree': None, 'predictor': None, 'random_state': 123, 'reg_alpha': None, 'reg_lambda': 0, 'sampling_method': None, 'scale_pos_weight': 7, 'subsample': None, 'tree_method': None, 'validate_parameters': None, 'verbosity': None, 'seed': 123}\n"
- ]
- }
- ],
- "source": [
- "model_params = readmission_prediction_task.list_models_params()[model_name]\n",
- "print(model_params)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Prediction\n",
- "\n",
- "The prediction output can be either the whole Hugging Face dataset with the prediction columns added to it or the single column containing the predicted values."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "# y_pred = readmission_prediction_task.predict(\n",
- "# dataset[\"test\"],\n",
- "# model_name=model_name,\n",
- "# transforms=preprocessor,\n",
- "# proba=True,\n",
- "# only_predictions=True,\n",
- "# )\n",
- "# prediction_df = pd.DataFrame(\n",
- "# {\n",
- "# \"y_prob\": [y_pred_i[1] for y_pred_i in y_pred],\n",
- "# \"y_true\": dataset[\"test\"][\"outcome\"],\n",
- "# }\n",
- "# )"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Evaluation\n",
- "\n",
- "Evaluation is done using various evaluation metrics that provide different perspectives on the model's predictive abilities i.e. standard performance metrics and fairness metrics.\n",
- "\n",
- "The standard performance metrics can be created using the `MetricDict` object."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 25,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "metric_names = [\n",
- " \"binary_accuracy\",\n",
- " \"binary_precision\",\n",
- " \"binary_recall\",\n",
- " \"binary_f1_score\",\n",
- " \"binary_auroc\",\n",
- " \"binary_average_precision\",\n",
- " \"binary_roc_curve\",\n",
- " \"binary_precision_recall_curve\",\n",
- "]\n",
- "metrics = [\n",
- " create_metric(metric_name, experimental=True) for metric_name in metric_names\n",
- "]\n",
- "metric_collection = MetricDict(metrics)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 26,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "specificity = create_metric(metric_name=\"binary_specificity\", experimental=True)\n",
- "sensitivity = create_metric(metric_name=\"binary_sensitivity\", experimental=True)\n",
- "\n",
- "fpr = -specificity + 1\n",
- "fnr = -sensitivity + 1\n",
- "\n",
- "ber = (fpr + fnr) / 2\n",
- "\n",
- "fairness_metric_collection = MetricDict(\n",
- " {\n",
- " \"Sensitivity\": sensitivity,\n",
- " \"Specificity\": specificity,\n",
- " \"BER\": ber,\n",
- " },\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The evaluate methods outputs the evaluation results and the Hugging Face dataset with the predictions added to it."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 27,
- "metadata": {},
- "outputs": [],
- "source": [
- "tester = Detectron(X_s=dataset[\"train\"],\n",
- " base_model=readmission_prediction_task.models['xgb_classifier'],\n",
- " feature_column=features_list,\n",
- " transforms=preprocessor,\n",
- " splits_mapping={\"train\": \"train\", \"test\": \"test\"},\n",
- " sample_size=250,\n",
- " num_runs=5,\n",
- " ensemble_size=5,\n",
- " task=\"binary\",\n",
- " save_dir=\"detectron\",\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 28,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "0.916030534351145\n"
- ]
- }
- ],
- "source": [
- "# get model health on all test data\n",
- "results = tester.predict(X_t = DatasetDict({\"train\": dataset[\"train\"][\"train\"], \"test\": dataset[\"test\"]}))\n",
- "print(results[\"model_health\"])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 29,
- "metadata": {},
- "outputs": [],
- "source": [
- "# split test data into 20 bins\n",
- "test_data = dataset[\"test\"]\n",
- "test_data_list = []\n",
- "\n",
- "indices = np.arange(0, len(test_data))\n",
- "\n",
- "bins = np.array_split(indices, 20)\n",
- "\n",
- "for bin in bins:\n",
- " test_data_list.append(test_data.select(bin))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 30,
- "metadata": {},
- "outputs": [],
- "source": [
- "# get model health on all test data bins\n",
- "model_health = []\n",
- "for data in test_data_list:\n",
- " results = tester.predict(X_t=DatasetDict({\"train\": dataset[\"train\"][\"train\"], \"test\": data}))\n",
- " model_health.append(results[\"model_health\"])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 31,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "application/vnd.plotly.v1+json": {
- "config": {
- "plotlyServerURL": "https://plot.ly"
- },
- "data": [
- {
- "hovertemplate": "bin=%{x} model_health=%{y} ",
- "legendgroup": "",
- "line": {
- "color": "#636efa",
- "dash": "solid"
- },
- "marker": {
- "symbol": "circle"
- },
- "mode": "lines",
- "name": "",
- "orientation": "v",
- "showlegend": false,
- "type": "scatter",
- "x": [
- 0,
- 1,
- 2,
- 3,
- 4,
- 5,
- 6,
- 7,
- 8,
- 9,
- 10,
- 11,
- 12,
- 13,
- 14,
- 15,
- 16,
- 17,
- 18,
- 19
- ],
- "xaxis": "x",
- "y": [
- 1,
- 1,
- 0.9541984732824428,
- 1,
- 1,
- 0.9923664122137404,
- 0.8015267175572519,
- 1,
- 0.5725190839694657,
- 0.6870229007633588,
- 1,
- 0.8015267175572519,
- 1,
- 0.9923664122137404,
- 0.916030534351145,
- 0.7633587786259542,
- 0.9923664122137404,
- 0.8015267175572519,
- 1,
- 1
- ],
- "yaxis": "y"
- }
- ],
- "layout": {
- "legend": {
- "tracegroupgap": 0
- },
- "template": {
- "data": {
- "bar": [
- {
- "error_x": {
- "color": "#2a3f5f"
- },
- "error_y": {
- "color": "#2a3f5f"
- },
- "marker": {
- "line": {
- "color": "#E5ECF6",
- "width": 0.5
- },
- "pattern": {
- "fillmode": "overlay",
- "size": 10,
- "solidity": 0.2
- }
- },
- "type": "bar"
- }
- ],
- "barpolar": [
- {
- "marker": {
- "line": {
- "color": "#E5ECF6",
- "width": 0.5
- },
- "pattern": {
- "fillmode": "overlay",
- "size": 10,
- "solidity": 0.2
- }
- },
- "type": "barpolar"
- }
- ],
- "carpet": [
- {
- "aaxis": {
- "endlinecolor": "#2a3f5f",
- "gridcolor": "white",
- "linecolor": "white",
- "minorgridcolor": "white",
- "startlinecolor": "#2a3f5f"
- },
- "baxis": {
- "endlinecolor": "#2a3f5f",
- "gridcolor": "white",
- "linecolor": "white",
- "minorgridcolor": "white",
- "startlinecolor": "#2a3f5f"
- },
- "type": "carpet"
- }
- ],
- "choropleth": [
- {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- },
- "type": "choropleth"
- }
- ],
- "contour": [
- {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- },
- "colorscale": [
- [
- 0,
- "#0d0887"
- ],
- [
- 0.1111111111111111,
- "#46039f"
- ],
- [
- 0.2222222222222222,
- "#7201a8"
- ],
- [
- 0.3333333333333333,
- "#9c179e"
- ],
- [
- 0.4444444444444444,
- "#bd3786"
- ],
- [
- 0.5555555555555556,
- "#d8576b"
- ],
- [
- 0.6666666666666666,
- "#ed7953"
- ],
- [
- 0.7777777777777778,
- "#fb9f3a"
- ],
- [
- 0.8888888888888888,
- "#fdca26"
- ],
- [
- 1,
- "#f0f921"
- ]
- ],
- "type": "contour"
- }
- ],
- "contourcarpet": [
- {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- },
- "type": "contourcarpet"
- }
- ],
- "heatmap": [
- {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- },
- "colorscale": [
- [
- 0,
- "#0d0887"
- ],
- [
- 0.1111111111111111,
- "#46039f"
- ],
- [
- 0.2222222222222222,
- "#7201a8"
- ],
- [
- 0.3333333333333333,
- "#9c179e"
- ],
- [
- 0.4444444444444444,
- "#bd3786"
- ],
- [
- 0.5555555555555556,
- "#d8576b"
- ],
- [
- 0.6666666666666666,
- "#ed7953"
- ],
- [
- 0.7777777777777778,
- "#fb9f3a"
- ],
- [
- 0.8888888888888888,
- "#fdca26"
- ],
- [
- 1,
- "#f0f921"
- ]
- ],
- "type": "heatmap"
- }
- ],
- "heatmapgl": [
- {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- },
- "colorscale": [
- [
- 0,
- "#0d0887"
- ],
- [
- 0.1111111111111111,
- "#46039f"
- ],
- [
- 0.2222222222222222,
- "#7201a8"
- ],
- [
- 0.3333333333333333,
- "#9c179e"
- ],
- [
- 0.4444444444444444,
- "#bd3786"
- ],
- [
- 0.5555555555555556,
- "#d8576b"
- ],
- [
- 0.6666666666666666,
- "#ed7953"
- ],
- [
- 0.7777777777777778,
- "#fb9f3a"
- ],
- [
- 0.8888888888888888,
- "#fdca26"
- ],
- [
- 1,
- "#f0f921"
- ]
- ],
- "type": "heatmapgl"
- }
- ],
- "histogram": [
- {
- "marker": {
- "pattern": {
- "fillmode": "overlay",
- "size": 10,
- "solidity": 0.2
- }
- },
- "type": "histogram"
- }
- ],
- "histogram2d": [
- {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- },
- "colorscale": [
- [
- 0,
- "#0d0887"
- ],
- [
- 0.1111111111111111,
- "#46039f"
- ],
- [
- 0.2222222222222222,
- "#7201a8"
- ],
- [
- 0.3333333333333333,
- "#9c179e"
- ],
- [
- 0.4444444444444444,
- "#bd3786"
- ],
- [
- 0.5555555555555556,
- "#d8576b"
- ],
- [
- 0.6666666666666666,
- "#ed7953"
- ],
- [
- 0.7777777777777778,
- "#fb9f3a"
- ],
- [
- 0.8888888888888888,
- "#fdca26"
- ],
- [
- 1,
- "#f0f921"
- ]
- ],
- "type": "histogram2d"
- }
- ],
- "histogram2dcontour": [
- {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- },
- "colorscale": [
- [
- 0,
- "#0d0887"
- ],
- [
- 0.1111111111111111,
- "#46039f"
- ],
- [
- 0.2222222222222222,
- "#7201a8"
- ],
- [
- 0.3333333333333333,
- "#9c179e"
- ],
- [
- 0.4444444444444444,
- "#bd3786"
- ],
- [
- 0.5555555555555556,
- "#d8576b"
- ],
- [
- 0.6666666666666666,
- "#ed7953"
- ],
- [
- 0.7777777777777778,
- "#fb9f3a"
- ],
- [
- 0.8888888888888888,
- "#fdca26"
- ],
- [
- 1,
- "#f0f921"
- ]
- ],
- "type": "histogram2dcontour"
- }
- ],
- "mesh3d": [
- {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- },
- "type": "mesh3d"
- }
- ],
- "parcoords": [
- {
- "line": {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- }
- },
- "type": "parcoords"
- }
- ],
- "pie": [
- {
- "automargin": true,
- "type": "pie"
- }
- ],
- "scatter": [
- {
- "fillpattern": {
- "fillmode": "overlay",
- "size": 10,
- "solidity": 0.2
- },
- "type": "scatter"
- }
- ],
- "scatter3d": [
- {
- "line": {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- }
- },
- "marker": {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- }
- },
- "type": "scatter3d"
- }
- ],
- "scattercarpet": [
- {
- "marker": {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- }
- },
- "type": "scattercarpet"
- }
- ],
- "scattergeo": [
- {
- "marker": {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- }
- },
- "type": "scattergeo"
- }
- ],
- "scattergl": [
- {
- "marker": {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- }
- },
- "type": "scattergl"
- }
- ],
- "scattermapbox": [
- {
- "marker": {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- }
- },
- "type": "scattermapbox"
- }
- ],
- "scatterpolar": [
- {
- "marker": {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- }
- },
- "type": "scatterpolar"
- }
- ],
- "scatterpolargl": [
- {
- "marker": {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- }
- },
- "type": "scatterpolargl"
- }
- ],
- "scatterternary": [
- {
- "marker": {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- }
- },
- "type": "scatterternary"
- }
- ],
- "surface": [
- {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- },
- "colorscale": [
- [
- 0,
- "#0d0887"
- ],
- [
- 0.1111111111111111,
- "#46039f"
- ],
- [
- 0.2222222222222222,
- "#7201a8"
- ],
- [
- 0.3333333333333333,
- "#9c179e"
- ],
- [
- 0.4444444444444444,
- "#bd3786"
- ],
- [
- 0.5555555555555556,
- "#d8576b"
- ],
- [
- 0.6666666666666666,
- "#ed7953"
- ],
- [
- 0.7777777777777778,
- "#fb9f3a"
- ],
- [
- 0.8888888888888888,
- "#fdca26"
- ],
- [
- 1,
- "#f0f921"
- ]
- ],
- "type": "surface"
- }
- ],
- "table": [
- {
- "cells": {
- "fill": {
- "color": "#EBF0F8"
- },
- "line": {
- "color": "white"
- }
- },
- "header": {
- "fill": {
- "color": "#C8D4E3"
- },
- "line": {
- "color": "white"
- }
- },
- "type": "table"
- }
- ]
- },
- "layout": {
- "annotationdefaults": {
- "arrowcolor": "#2a3f5f",
- "arrowhead": 0,
- "arrowwidth": 1
- },
- "autotypenumbers": "strict",
- "coloraxis": {
- "colorbar": {
- "outlinewidth": 0,
- "ticks": ""
- }
- },
- "colorscale": {
- "diverging": [
- [
- 0,
- "#8e0152"
- ],
- [
- 0.1,
- "#c51b7d"
- ],
- [
- 0.2,
- "#de77ae"
- ],
- [
- 0.3,
- "#f1b6da"
- ],
- [
- 0.4,
- "#fde0ef"
- ],
- [
- 0.5,
- "#f7f7f7"
- ],
- [
- 0.6,
- "#e6f5d0"
- ],
- [
- 0.7,
- "#b8e186"
- ],
- [
- 0.8,
- "#7fbc41"
- ],
- [
- 0.9,
- "#4d9221"
- ],
- [
- 1,
- "#276419"
- ]
- ],
- "sequential": [
- [
- 0,
- "#0d0887"
- ],
- [
- 0.1111111111111111,
- "#46039f"
- ],
- [
- 0.2222222222222222,
- "#7201a8"
- ],
- [
- 0.3333333333333333,
- "#9c179e"
- ],
- [
- 0.4444444444444444,
- "#bd3786"
- ],
- [
- 0.5555555555555556,
- "#d8576b"
- ],
- [
- 0.6666666666666666,
- "#ed7953"
- ],
- [
- 0.7777777777777778,
- "#fb9f3a"
- ],
- [
- 0.8888888888888888,
- "#fdca26"
- ],
- [
- 1,
- "#f0f921"
- ]
- ],
- "sequentialminus": [
- [
- 0,
- "#0d0887"
- ],
- [
- 0.1111111111111111,
- "#46039f"
- ],
- [
- 0.2222222222222222,
- "#7201a8"
- ],
- [
- 0.3333333333333333,
- "#9c179e"
- ],
- [
- 0.4444444444444444,
- "#bd3786"
- ],
- [
- 0.5555555555555556,
- "#d8576b"
- ],
- [
- 0.6666666666666666,
- "#ed7953"
- ],
- [
- 0.7777777777777778,
- "#fb9f3a"
- ],
- [
- 0.8888888888888888,
- "#fdca26"
- ],
- [
- 1,
- "#f0f921"
- ]
- ]
- },
- "colorway": [
- "#636efa",
- "#EF553B",
- "#00cc96",
- "#ab63fa",
- "#FFA15A",
- "#19d3f3",
- "#FF6692",
- "#B6E880",
- "#FF97FF",
- "#FECB52"
- ],
- "font": {
- "color": "#2a3f5f"
- },
- "geo": {
- "bgcolor": "white",
- "lakecolor": "white",
- "landcolor": "#E5ECF6",
- "showlakes": true,
- "showland": true,
- "subunitcolor": "white"
- },
- "hoverlabel": {
- "align": "left"
- },
- "hovermode": "closest",
- "mapbox": {
- "style": "light"
- },
- "paper_bgcolor": "white",
- "plot_bgcolor": "#E5ECF6",
- "polar": {
- "angularaxis": {
- "gridcolor": "white",
- "linecolor": "white",
- "ticks": ""
- },
- "bgcolor": "#E5ECF6",
- "radialaxis": {
- "gridcolor": "white",
- "linecolor": "white",
- "ticks": ""
- }
- },
- "scene": {
- "xaxis": {
- "backgroundcolor": "#E5ECF6",
- "gridcolor": "white",
- "gridwidth": 2,
- "linecolor": "white",
- "showbackground": true,
- "ticks": "",
- "zerolinecolor": "white"
- },
- "yaxis": {
- "backgroundcolor": "#E5ECF6",
- "gridcolor": "white",
- "gridwidth": 2,
- "linecolor": "white",
- "showbackground": true,
- "ticks": "",
- "zerolinecolor": "white"
- },
- "zaxis": {
- "backgroundcolor": "#E5ECF6",
- "gridcolor": "white",
- "gridwidth": 2,
- "linecolor": "white",
- "showbackground": true,
- "ticks": "",
- "zerolinecolor": "white"
- }
- },
- "shapedefaults": {
- "line": {
- "color": "#2a3f5f"
- }
- },
- "ternary": {
- "aaxis": {
- "gridcolor": "white",
- "linecolor": "white",
- "ticks": ""
- },
- "baxis": {
- "gridcolor": "white",
- "linecolor": "white",
- "ticks": ""
- },
- "bgcolor": "#E5ECF6",
- "caxis": {
- "gridcolor": "white",
- "linecolor": "white",
- "ticks": ""
- }
- },
- "title": {
- "x": 0.05
- },
- "xaxis": {
- "automargin": true,
- "gridcolor": "white",
- "linecolor": "white",
- "ticks": "",
- "title": {
- "standoff": 15
- },
- "zerolinecolor": "white",
- "zerolinewidth": 2
- },
- "yaxis": {
- "automargin": true,
- "gridcolor": "white",
- "linecolor": "white",
- "ticks": "",
- "title": {
- "standoff": 15
- },
- "zerolinecolor": "white",
- "zerolinewidth": 2
- }
- }
- },
- "title": {
- "text": "Model Health"
- },
- "xaxis": {
- "anchor": "y",
- "domain": [
- 0,
- 1
- ],
- "title": {
- "text": "bin"
- }
- },
- "yaxis": {
- "anchor": "x",
- "domain": [
- 0,
- 1
- ],
- "title": {
- "text": "model_health"
- }
- }
- }
- }
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "# use plotly to visualize the model health over the bins\n",
- "model_health_df = pd.DataFrame(model_health, columns=[\"model_health\"])\n",
- "\n",
- "model_health_df[\"bin\"] = np.arange(0, len(model_health_df))\n",
- "\n",
- "fig = px.line(model_health_df, x=\"bin\", y=\"model_health\", title=\"Model Health\")\n",
- "fig.show()"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.7"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}