diff --git a/use_cases/mortality_prediction.ipynb b/use_cases/mortality_prediction.ipynb deleted file mode 100644 index 79da859b2..000000000 --- a/use_cases/mortality_prediction.ipynb +++ /dev/null @@ -1,696 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Mortality Prediction using Tabular Data\n", - "\n", - "This notebooks presents the use-case of predicting the risk of mortality in patients on Mimic-IV dataset." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "import glob\n", - "\n", - "import dask.dataframe as dd\n", - "import pandas as pd\n", - "from datasets import Dataset\n", - "from datasets.features import ClassLabel\n", - "from sklearn.compose import ColumnTransformer\n", - "from sklearn.impute import SimpleImputer\n", - "from sklearn.pipeline import Pipeline\n", - "from sklearn.preprocessing import OneHotEncoder, StandardScaler\n", - "\n", - "from cyclops.data.slicer import SliceSpec\n", - "from cyclops.evaluate.metrics import MetricCollection, create_metric\n", - "from cyclops.models.catalog import create_model\n", - "from cyclops.process.aggregate import Aggregator\n", - "from cyclops.process.column_names import (\n", - " AGE,\n", - " ENCOUNTER_ID,\n", - " EVENT_NAME,\n", - " EVENT_TIMESTAMP,\n", - " EVENT_VALUE,\n", - " SEX,\n", - ")\n", - "from cyclops.process.constants import MEAN, NUMERIC, STRING\n", - "from cyclops.process.feature.feature import TabularFeatures\n", - "from cyclops.tasks.mortality_prediction import MortalityPredictionTask\n", - "from cyclops.utils.file import join, load_dataframe, process_dir_save_path" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Constants" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "DATASET = \"mimiciv\"\n", - "CONST_NAME = \"mortality_decompensation\"\n", - "\n", - "USECASE_ROOT_DIR = join(\n", - " \"/mnt/data\",\n", - " \"cyclops\",\n", - " \"use_cases\",\n", - " DATASET,\n", - " CONST_NAME,\n", - ")\n", - "DATA_DIR = process_dir_save_path(join(USECASE_ROOT_DIR, \"./data\"))\n", - "ENCOUNTERS_FILE = join(DATA_DIR, \"encounters.parquet\")\n", - "EVENTS_DIR = join(DATA_DIR, \"./1-cleaned\")\n", - "\n", - "OUTCOME_DEATH = \"outcome_death\"\n", - "TAB_FEATURES = [\n", - " AGE,\n", - " SEX,\n", - " \"admission_type\",\n", - " \"admission_location\",\n", - "]\n", - "\n", - "SPLIT_FRACTIONS = [0.8, 0.1, 0.1]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Features\n", - "A list of selected events along side age, sex, admission type and admission location are used as the features." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "selected_events = [\n", - " \"routine vital signs - arterial blood pressure diastolic\",\n", - " \"routine vital signs - arterial blood pressure mean\",\n", - " \"routine vital signs - arterial blood pressure systolic\",\n", - " \"routine vital signs - heart rate\",\n", - " \"routine vital signs - non invasive blood pressure diastolic\",\n", - " \"routine vital signs - non invasive blood pressure mean\",\n", - " \"routine vital signs - non invasive blood pressure systolic\",\n", - " \"labs - arterial base excess\",\n", - " \"labs - arterial co2 pressure\",\n", - " \"labs - arterial o2 pressure\",\n", - " \"labs - bun\",\n", - " \"labs - calcium non-ionized\",\n", - " \"labs - chloride\",\n", - " \"labs - creatinine\",\n", - " \"labs - glucose\",\n", - " \"labs - glucose finger stick\",\n", - " \"labs - hco3\",\n", - " \"labs - hematocrit\",\n", - " \"labs - hemoglobin\",\n", - " \"labs - inr\",\n", - " \"labs - magnesium\",\n", - " \"labs - ph\",\n", - " \"labs - phosphorous\",\n", - " \"labs - platelet count\",\n", - " \"labs - potassium\",\n", - " \"labs - prothrombin time\",\n", - " \"labs - ptt\",\n", - " \"labs - sodium\",\n", - " \"labs - tco2 arterial\",\n", - " \"labs - wbc\",\n", - " \"respiratory - apnea interval\",\n", - " \"respiratory - fspn high\",\n", - " \"respiratory - inspired o2 fraction\",\n", - " \"respiratory - mean airway pressure\",\n", - " \"respiratory - minute volume\",\n", - " \"respiratory - o2 flow\",\n", - " \"respiratory - o2 saturation pulseoxymetry\",\n", - " \"respiratory - paw high\",\n", - " \"respiratory - peak insp. pressure\",\n", - " \"respiratory - peep set\",\n", - " \"respiratory - respiratory rate\",\n", - " \"respiratory - tidal volume\",\n", - " \"respiratory - ventilator mode\",\n", - " \"respiratory - vti high\",\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "features_list = sorted(selected_events + TAB_FEATURES)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Data Loading\n", - "\n", - "The tabular features are extracted from the encouter records and the events are extracted from the saved parquet files." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "encounters_df = load_dataframe(ENCOUNTERS_FILE)\n", - "encounters_df" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "parquet_files = list(glob.glob(join(EVENTS_DIR, \"*.parquet\")))\n", - "ddf = dd.read_parquet(path=parquet_files)\n", - "ddf.npartitions" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Data Preprocessing" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In the preprocessing step, the values of selected events are aggregated. After creating a dataframe that contains all the features, Scikit-learn transformations are used to preprocess the categorical and numerical data." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Instantiate the aggregator\n", - "aggregator = Aggregator(\n", - " aggfuncs={EVENT_VALUE: MEAN},\n", - " timestamp_col=EVENT_TIMESTAMP,\n", - " time_by=ENCOUNTER_ID,\n", - " agg_by=[ENCOUNTER_ID, EVENT_NAME],\n", - " timestep_size=24,\n", - " window_duration=72,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# A function to filter data based on the time events were recorded.\n", - "# Here, the deathtime must be at least 24 hours after the last event.\n", - "\n", - "\n", - "def stop_bound(data: pd.DataFrame):\n", - " return data[\n", - " (data[\"deathtime\"] > (data[\"stop\"] + pd.Timedelta(hours=24)))\n", - " | (data[OUTCOME_DEATH] != 1)\n", - " ]\n", - "\n", - "\n", - "# Perform aggregation on data partitions\n", - "def process_partition(partition):\n", - " invalid = (partition[\"hospital_expire_flag\"] == 1) & (partition[\"deathtime\"].isna())\n", - " partition = partition[~invalid]\n", - " partition[OUTCOME_DEATH] = partition[\"hospital_expire_flag\"] == 1\n", - " partition[OUTCOME_DEATH] = partition[OUTCOME_DEATH].astype(int)\n", - " # Aggregate events values over the window duration\n", - " aggregated = aggregator.aggregate_values(partition, stop_bound_func=stop_bound)\n", - " # Convert the events rows into feature columns\n", - " aggregated = aggregated.reset_index()\n", - " return aggregated.pivot(\n", - " index=ENCOUNTER_ID,\n", - " columns=EVENT_NAME,\n", - " values=EVENT_VALUE,\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# run aggregation on data partitions and create the final dataframe\n", - "selected_events = sorted(selected_events)\n", - "ddf = ddf.loc[ddf[EVENT_NAME].isin(selected_events)]\n", - "meta_df = pd.DataFrame(columns=selected_events)\n", - "ddf = ddf.map_partitions(process_partition, meta=meta_df)\n", - "df = ddf.compute()\n", - "df" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# merge events data with encounters data which can be used for slicing during evaluation\n", - "df = df.merge(encounters_df, on=ENCOUNTER_ID).set_index(ENCOUNTER_ID)\n", - "df[OUTCOME_DEATH] = df[OUTCOME_DEATH].astype(int)\n", - "df" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# the data is heavily unbalanced\n", - "class_counts = df[OUTCOME_DEATH].value_counts()\n", - "class_ratio = class_counts[0] / class_counts[1]\n", - "class_ratio" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# use TabularFeatures object to detect the feature types\n", - "tab_features = TabularFeatures(\n", - " data=df.reset_index(),\n", - " features=features_list,\n", - " by=ENCOUNTER_ID,\n", - " targets=OUTCOME_DEATH,\n", - " force_types={\n", - " SEX: STRING,\n", - " \"admission_location\": STRING,\n", - " \"admission_type\": STRING,\n", - " },\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "categorical_features = sorted(tab_features.features_by_type(STRING))\n", - "categorical_features" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "numeric_features = sorted((tab_features.features_by_type(NUMERIC)))\n", - "numeric_features" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# extract the indices of the features from a subset of the dataframe \\\n", - "# that is used for modeling\n", - "categorical_indices = [\n", - " df[features_list].columns.get_loc(column) for column in categorical_features\n", - "]\n", - "numeric_indices = [\n", - " df[features_list].columns.get_loc(column) for column in numeric_features\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# define the preprocessor\n", - "numeric_transformer = Pipeline(\n", - " steps=[(\"imputer\", SimpleImputer(strategy=\"mean\")), (\"scaler\", StandardScaler())],\n", - ")\n", - "categorical_transformer = OneHotEncoder(handle_unknown=\"ignore\")\n", - "preprocessor = ColumnTransformer(\n", - " transformers=[\n", - " (\"num\", numeric_transformer, numeric_indices),\n", - " (\"cat\", categorical_transformer, categorical_indices),\n", - " ],\n", - " remainder=\"passthrough\",\n", - ")\n", - "# fit the preprocessor for later use\n", - "preprocessor = preprocessor.fit(df[features_list])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Dataset Creation\n", - "\n", - "The dataframe is converted to Hugging Face Dataset which is necessary for evaluation and optional for training and inference." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "dataset = Dataset.from_pandas(df)\n", - "dataset.cleanup_cache_files()\n", - "dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# split the data to train and test subsets while preserving \\\n", - "# the same class ratio for both subsets\n", - "dataset = dataset.cast_column(OUTCOME_DEATH, ClassLabel(num_classes=2))\n", - "dataset = dataset.train_test_split(\n", - " train_size=SPLIT_FRACTIONS[0],\n", - " stratify_by_column=OUTCOME_DEATH,\n", - " seed=42,\n", - ")\n", - "dataset[\"train\"]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Model Creation\n", - "\n", - "The CyclOps Model API is used to create models using estimators from the Scikit-learn package. The configuration of the model is based on the corresponding config files, which include the necessary parameters for instantiating the Scikit-learn estimators, as well as optional parameters for hyperparameter search." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "mlp_name = \"mlp\"\n", - "mlp_model = create_model(mlp_name)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "xgb_name = \"xgb_classifier\"\n", - "\n", - "# handle imbalanced data by setting scale_pos_weight\n", - "xgb_model = create_model(xgb_name, scale_pos_weight=class_ratio)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Mortality Prediction Task\n", - "\n", - "The CyclOps Task API is used to create a Mortality Prediction Task based on the available models and dataset. The task can contain multiple models that can be trained and used for prediction individually. This is particularly useful when comparing the performance of multiple models during the evaluation step." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "mortality_task = MortalityPredictionTask(\n", - " {xgb_name: xgb_model},\n", - " task_features=features_list,\n", - " task_target=OUTCOME_DEATH,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "mortality_task.add_model(mlp_name)\n", - "mortality_task.list_models()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "mortality_task.list_models_params()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Training\n", - "\n", - "The `train` method can be used for Hugging Face datasets, numpy arrays or dataframes (containing only the relevant columns). \n", - "If the input data is a Hugging Face Dataset and not preprocessed already, a `ColumnTransformer` object can be passed as transforms." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "mortality_task.train(\n", - " dataset[\"train\"],\n", - " model_name=xgb_name,\n", - " transforms=preprocessor,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "mortality_task.train(dataset[\"train\"], model_name=mlp_name, transforms=preprocessor)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Prediction\n", - "\n", - "In the prediction phase, the task object allows for a variety of data inputs, including numpy arrays, pandas dataframes, and Hugging Face Datasets.\n", - "\n", - "When using a Hugging Face dataset as the input, it is possible to obtain the entire dataset with the added prediction column as the output of the predict method." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# run prediction\n", - "ds_with_mlp_preds = mortality_task.predict(\n", - " dataset,\n", - " model_name=xgb_name,\n", - " prediction_column_prefix=\"preds\",\n", - " transforms=preprocessor,\n", - " only_predictions=False,\n", - " proba=True,\n", - " splits_mapping={\"test\": \"test\"},\n", - ")\n", - "ds_with_mlp_preds.to_pandas()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Evaluation\n", - "\n", - "Evaluation is typically performed on a Hugging Face dataset. To evaluate the models, you can also provide a slice specification to see how well they perform for different slices of data based on the feature values.\n", - "\n", - "In addition to the dataset and slice specification, you need to specify the desired evaluation metrics. This can be done by providing a MetricCollection object, a list of metrics, or metric names.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "spec_list = [\n", - " {\"sex\": {\"value\": \"M\"}}, # feature value is M\n", - " {\n", - " \"age\": {\n", - " \"min_value\": 18,\n", - " \"max_value\": 65,\n", - " \"min_inclusive\": True,\n", - " \"max_inclusive\": False,\n", - " },\n", - " }, # feature value is between 18 and 65, inclusive of 18, exclusive of 65\n", - " {\n", - " \"admission_type\": {\"value\": [\"EW EMER.\", \"DIRECT EMER.\", \"URGENT\"]},\n", - " }, # feature value is in the list\n", - " {\n", - " \"admission_location\": {\n", - " \"value\": [\"PHYSICIAN REFERRAL\", \"CLINIC REFERRAL\", \"WALK-IN/SELF REFERRAL\"],\n", - " \"negate\": True,\n", - " },\n", - " }, # feature value is NOT in the list\n", - " {\n", - " \"dod\": {\"max_value\": \"2019-12-01\", \"keep_nulls\": False},\n", - " }, # possibly before COVID-19\n", - " {\n", - " \"dod\": {\"max_value\": \"2019-12-01\", \"negate\": True, \"keep_nulls\": False},\n", - " }, # possibly during COVID-19\n", - " {\"admit_timestamp\": {\"month\": [6, 7, 8, 9], \"keep_nulls\": False}},\n", - " {\n", - " \"sex\": {\"value\": \"F\"},\n", - " \"race\": {\n", - " \"value\": [\n", - " \"BLACK/AFRICAN AMERICAN\",\n", - " \"BLACK/CARIBBEAN ISLAND\",\n", - " \"BLACK/CAPE VERDEAN\",\n", - " \"BLACK/AFRICAN\",\n", - " ],\n", - " },\n", - " \"age\": {\"min_value\": 25, \"max_value\": 40},\n", - " }, # compound slice\n", - "]\n", - "\n", - "\n", - "# create the slice functions\n", - "slice_spec = SliceSpec(spec_list)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# define metrics\n", - "metric_names = [\"accuracy\", \"precision\", \"recall\", \"f1_score\", \"auroc\"]\n", - "metrics = [create_metric(metric_name, task=\"binary\") for metric_name in metric_names]\n", - "metric_collection = MetricCollection(metrics)\n", - "\n", - "# run evaluation\n", - "results, dataset_with_preds = mortality_task.evaluate(\n", - " dataset[\"test\"],\n", - " metric_collection,\n", - " model_names=xgb_name,\n", - " transforms=preprocessor,\n", - " prediction_column_prefix=\"preds\",\n", - " slice_spec=slice_spec,\n", - " batch_size=200,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "results" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "dataset_with_preds" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -}