From 4b42ac48cabe47104f6204e4df0d67ee787e3d77 Mon Sep 17 00:00:00 2001 From: Amrit Krishnan Date: Wed, 22 Nov 2023 11:31:24 -0500 Subject: [PATCH] Clean up use cases, add new tabular prediction on MIMICIV (#510) * Clean up use cases, add new tabular prediction on MIMICIV * Bring example use cases first * Remove decompensation notebook, its not ready --- cyclops/process/aggregate.py | 21 +- docs/source/tutorials.rst | 2 +- .../mimiciv/mortality_prediction.ipynb | 1279 ++++++++++++++ .../tutorials/synthea/los_prediction.ipynb | 40 +- docs/source/tutorials_use_cases.rst | 9 + poetry.lock | 40 +- pyproject.toml | 1 - use_cases/__init__.py | 1 - use_cases/common/util.py | 25 - .../mimiciv/mortality_decompensation.ipynb | 259 --- use_cases/data_processors/mimiciv.py | 1476 ----------------- use_cases/data_processors/process.py | 56 - .../mimiciv/mortality_decompensation.ipynb | 433 ----- .../mortality_decompensation/constants.py | 141 -- .../mortality_decompensation/constants_v1.py | 57 - use_cases/util.py | 154 -- 16 files changed, 1324 insertions(+), 2670 deletions(-) create mode 100644 docs/source/tutorials/mimiciv/mortality_prediction.ipynb delete mode 100644 use_cases/__init__.py delete mode 100644 use_cases/common/util.py delete mode 100644 use_cases/data_collectors/mimiciv/mortality_decompensation.ipynb delete mode 100644 use_cases/data_processors/mimiciv.py delete mode 100644 use_cases/data_processors/process.py delete mode 100644 use_cases/examples/mimiciv/mortality_decompensation.ipynb delete mode 100644 use_cases/params/mimiciv/mortality_decompensation/constants.py delete mode 100644 use_cases/params/mimiciv/mortality_decompensation/constants_v1.py delete mode 100644 use_cases/util.py diff --git a/cyclops/process/aggregate.py b/cyclops/process/aggregate.py index 94d9438ff..1bba696f0 100644 --- a/cyclops/process/aggregate.py +++ b/cyclops/process/aggregate.py @@ -66,7 +66,7 @@ def __init__( timestamp_col: str, time_by: Union[str, List[str]], agg_by: Union[str, List[str]], - timestep_size: int, + timestep_size: Optional[int] = None, window_duration: Optional[int] = None, imputer: Optional[AggregatedImputer] = None, agg_meta_for: Optional[List[str]] = None, @@ -78,9 +78,9 @@ def __init__( self.timestamp_col = timestamp_col self.time_by = to_list(time_by) self.agg_by = to_list(agg_by) + self.agg_meta_for = to_list_optional(agg_meta_for) self.timestep_size = timestep_size self.window_duration = window_duration - self.agg_meta_for = to_list_optional(agg_meta_for) self.window_times = pd.DataFrame() # Calculated when given the data self.imputer = imputer # Parameter checking @@ -90,8 +90,8 @@ def __init__( raise ValueError( "Cannot compute meta for a column not being aggregated.", ) - if self.window_duration is not None: - divided = self.window_duration / self.timestep_size + if window_duration is not None and timestep_size is not None: + divided = window_duration / timestep_size if divided != int(divided): raise ValueError("Window duration be divisible by bucket size.") @@ -568,6 +568,10 @@ def vectorize(self, aggregated: pd.DataFrame) -> Vectorized: raise NotImplementedError( "Cannot currently vectorize data aggregated with no window duration.", ) + if self.timestep_size is None: + raise NotImplementedError( + "Cannot currently vectorize data aggregated with no timestep size.", + ) num_timesteps = int(self.window_duration / self.timestep_size) # Parameter checking has_columns(aggregated, list(self.aggfuncs.keys()), raise_error=True) @@ -605,8 +609,6 @@ def aggregate_values( data: 
pd.DataFrame, window_start_time: Optional[pd.DataFrame] = None, window_stop_time: Optional[pd.DataFrame] = None, - start_bound_func: Optional[Callable[[pd.Series], pd.Series]] = None, - stop_bound_func: Optional[Callable[[pd.Series], pd.Series]] = None, ) -> pd.DataFrame: """Aggregate temporal values. @@ -622,10 +624,6 @@ def aggregate_values( window_stop_time: pd.DataFrame, optional An optionally provided window stop time. This cannot be provided if window_duration was set. - start_bound_func : Optional[Callable[[pd.Series], pd.Series]], optional - A function to bound the start timestamp values, by default None - stop_bound_func : Optional[Callable[[pd.Series], pd.Series]], optional - A function to bound the start timestamp values, by default None Returns ------- @@ -648,9 +646,6 @@ def aggregate_values( ) # Restrict the data according to the start/stop data = self._restrict_by_timestamp(data) - # Filter the data based on bounds on start/stop - data = start_bound_func(data) if start_bound_func else data - data = stop_bound_func(data) if stop_bound_func else data grouped = data.groupby(self.agg_by, sort=False) return grouped.agg(self.aggfuncs) diff --git a/docs/source/tutorials.rst b/docs/source/tutorials.rst index 490aaaf11..5a81267aa 100644 --- a/docs/source/tutorials.rst +++ b/docs/source/tutorials.rst @@ -3,5 +3,5 @@ Tutorials .. toctree:: - tutorials_monitor tutorials_use_cases + tutorials_monitor diff --git a/docs/source/tutorials/mimiciv/mortality_prediction.ipynb b/docs/source/tutorials/mimiciv/mortality_prediction.ipynb new file mode 100644 index 000000000..fcabe1b54 --- /dev/null +++ b/docs/source/tutorials/mimiciv/mortality_prediction.ipynb @@ -0,0 +1,1279 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Mortality Prediction\n", + "\n", + "This notebook showcases mortality prediction on the [MIMICIV](https://physionet.org/content/mimiciv/2.0) dataset using CyclOps. The task is formulated as a binary classification task, whether the patient will die within the next N days. The prediction can be made after M number of days after admission. For example, if N = 14 and M = 1, we are predicting risk of patient mortality within 14 days of admission after considering 24 hours of data after admission." 
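For example, with N = 14 and M = 1 the label construction reduces to two datetime comparisons on the encounter table. A minimal sketch of that logic in plain pandas (assuming `admittime`, `dischtime` and `deathtime` columns, as queried later in this notebook):

```python
import pandas as pd

N, M = 14, 1  # predict death within N days, using the first M days of data


def label_mortality(encounters: pd.DataFrame) -> pd.DataFrame:
    """Keep encounters of at least M days and flag death within N days of admission."""
    los_days = (encounters["dischtime"] - encounters["admittime"]).dt.days
    encounters = encounters[los_days >= M].copy()  # need at least M days of observed data
    days_to_death = (encounters["deathtime"] - encounters["admittime"]).dt.days
    encounters["mortality_outcome"] = (days_to_death <= N).astype(int)  # missing deathtime -> 0
    return encounters
```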
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Import Libraries" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"Mortality Prediction.\"\"\"\n", + "\n", + "import copy\n", + "import shutil\n", + "from datetime import date\n", + "\n", + "import cycquery.ops as qo\n", + "import numpy as np\n", + "import pandas as pd\n", + "import plotly.express as px\n", + "import plotly.graph_objects as go\n", + "from cycquery import MIMICIVQuerier\n", + "from datasets import Dataset\n", + "from datasets.features import ClassLabel\n", + "from sklearn.compose import ColumnTransformer\n", + "from sklearn.impute import SimpleImputer\n", + "from sklearn.pipeline import Pipeline\n", + "from sklearn.preprocessing import MinMaxScaler, OneHotEncoder\n", + "\n", + "from cyclops.data.slicer import SliceSpec\n", + "from cyclops.evaluate.fairness import FairnessConfig # noqa: E402\n", + "from cyclops.evaluate.metrics import MetricCollection, create_metric\n", + "from cyclops.models.catalog import create_model\n", + "from cyclops.process.aggregate import RESTRICT_TIMESTAMP, Aggregator\n", + "from cyclops.process.clean import normalize_names\n", + "from cyclops.process.feature.feature import TabularFeatures\n", + "from cyclops.report import ModelCardReport\n", + "from cyclops.report.plot.classification import ClassificationPlotter\n", + "from cyclops.report.utils import flatten_results_dict\n", + "from cyclops.tasks import BinaryTabularClassificationTask\n", + "from cyclops.utils.common import add_years_approximate" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "CyclOps offers a package for documentation of the model through a model report. The `ModelCardReport` class is used to populate and generate the model report as an HTML file. 
The model report has the following sections:\n", + "\n", + "- Overview: Provides a high level overview of how the model is doing (a quick glance of important metrics), and how it is doing over time (performance over several metrics and subgroups over time).\n", + "- Datasets: High level statistics of the training data, including changes in distribution over time.\n", + "- Quantitative Analysis: This section contains additional detailed performance metrics of the model for different sets of the data and subpopulations.\n", + "- Fairness Analysis: This section contains the fairness metrics of the model.\n", + "- Model Details: This section contains descriptive metadata about the model such as the owners, version, license, etc.\n", + "- Model Parameters: This section contains the technical details of the model such as the model architecture, training parameters, etc.\n", + "- Considerations: This section contains descriptions of the considerations involved in developing and using the model such as the intended use, limitations, etc.\n", + "\n", + "We will use this to document the model development process as we go along and generate the model report at the end.\n", + "\n", + "`The model report tool is a work in progress and is subject to change.`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report = ModelCardReport()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Constants" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "M = 1\n", + "N = 14\n", + "NAN_THRESHOLD = 0.25\n", + "TRAIN_SIZE = 0.8\n", + "RANDOM_SEED = 12" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data Querying & Processing\n", + "\n", + "### Compute mortality (labels)\n", + "\n", + "1. Get encounters\n", + "2. Filter out encounters less than M days\n", + "3. Set label = 1 for encounters where deathtime is within N days after admission\n", + "4. Get lab events\n", + "5. 
Aggregate them by computing mean, merge with encounter data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "querier = MIMICIVQuerier(\n", + " dbms=\"postgresql\",\n", + " port=5432,\n", + " host=\"localhost\",\n", + " database=\"mimiciv-2.0\",\n", + " user=\"postgres\",\n", + " password=\"pwd\",\n", + ")\n", + "\n", + "\n", + "def get_encounters():\n", + " \"\"\"Get encounters data.\"\"\"\n", + " patients = querier.patients()\n", + " encounters = querier.mimiciv_hosp.admissions()\n", + " drop_op = qo.Drop(\n", + " [\"insurance\", \"language\", \"marital_status\", \"edregtime\", \"edouttime\"],\n", + " )\n", + " encounters = encounters.ops(drop_op)\n", + " patient_encounters = patients.join(encounters, on=\"subject_id\")\n", + " patient_encounters = patient_encounters.run()\n", + " patient_encounters[\"age\"] = (\n", + " patient_encounters[\"admittime\"].dt.year\n", + " - patient_encounters[\"anchor_year\"]\n", + " + patient_encounters[\"anchor_age\"]\n", + " )\n", + " for col in [\"admittime\", \"dischtime\", \"deathtime\"]:\n", + " patient_encounters[col] = add_years_approximate(\n", + " patient_encounters[col],\n", + " patient_encounters[\"anchor_year_difference\"],\n", + " )\n", + "\n", + " return patient_encounters[\n", + " [\n", + " \"hadm_id\",\n", + " \"admittime\",\n", + " \"dischtime\",\n", + " \"deathtime\",\n", + " \"anchor_age\",\n", + " \"age\",\n", + " \"gender\",\n", + " \"anchor_year_difference\",\n", + " \"admission_location\",\n", + " \"hospital_expire_flag\",\n", + " ]\n", + " ]\n", + "\n", + "\n", + "def compute_mortality_outcome(patient_encounters):\n", + " \"\"\"Compute mortality outcome.\"\"\"\n", + " # Drop encounters ending in death which don't have a death timestamp\n", + " invalid = (patient_encounters[\"hospital_expire_flag\"] == 1) & (\n", + " patient_encounters[\"deathtime\"].isna()\n", + " )\n", + " patient_encounters = patient_encounters[~invalid]\n", + " print(f\"Encounters with death flag but no death timestamp: {invalid.sum()}\")\n", + " # Drop encounters which are shorter than M days\n", + " invalid = (\n", + " patient_encounters[\"dischtime\"] - patient_encounters[\"admittime\"]\n", + " ).dt.days < M\n", + " patient_encounters = patient_encounters[~invalid]\n", + " print(f\"Encounters shorter than {M} days: {invalid.sum()}\")\n", + " # Death timestamp is within (<=) N days of admission\n", + " valid = (\n", + " patient_encounters[\"deathtime\"] - patient_encounters[\"admittime\"]\n", + " ).dt.days <= N\n", + " print(f\"Encounters with death timestamp within {N} days: {valid.sum()}\")\n", + " # (Died in hospital) & (Death timestamp is defined)\n", + " patient_encounters[\"mortality_outcome\"] = 0\n", + " patient_encounters[\"mortality_outcome\"][valid] = 1\n", + " print(\n", + " f\"Encounters with mortality outcome for the model: {patient_encounters['mortality_outcome'].sum()}\",\n", + " )\n", + "\n", + " return patient_encounters\n", + "\n", + "\n", + "def get_labevents(patient_encounters):\n", + " \"\"\"Get labevents data.\"\"\"\n", + " labevents = querier.labevents().run(index_col=\"hadm_id\", batch_mode=True)\n", + "\n", + " def process_labevents(labevents, patient_encounters):\n", + " \"\"\"Process labevents before aggregation.\"\"\"\n", + " # Reverse deidentified dating\n", + " labevents = pd.merge(\n", + " patient_encounters[\n", + " [\n", + " \"hadm_id\",\n", + " \"anchor_year_difference\",\n", + " ]\n", + " ],\n", + " labevents,\n", + " on=\"hadm_id\",\n", + " )\n", + " 
labevents[\"charttime\"] = add_years_approximate(\n", + " labevents[\"charttime\"],\n", + " labevents[\"anchor_year_difference\"],\n", + " )\n", + " labevents = labevents.drop(\"anchor_year_difference\", axis=1)\n", + " # Pre-processing\n", + " labevents[\"label\"] = normalize_names(labevents[\"label\"])\n", + " labevents[\"category\"] = normalize_names(labevents[\"category\"])\n", + "\n", + " return labevents\n", + "\n", + " start_timestamps = (\n", + " patient_encounters[[\"hadm_id\", \"admittime\"]]\n", + " .set_index(\"hadm_id\")\n", + " .rename({\"admittime\": RESTRICT_TIMESTAMP}, axis=1)\n", + " )\n", + " mean_aggregator = Aggregator(\n", + " aggfuncs={\n", + " \"valuenum\": \"mean\",\n", + " },\n", + " window_duration=M * 24,\n", + " timestamp_col=\"charttime\",\n", + " time_by=\"hadm_id\",\n", + " agg_by=[\"hadm_id\", \"label\"],\n", + " )\n", + " means_df = pd.DataFrame()\n", + " for batch_num, labevents_batch in enumerate(labevents):\n", + " labevents_batch = process_labevents( # noqa: PLW2901\n", + " labevents_batch,\n", + " patient_encounters,\n", + " )\n", + " means = mean_aggregator.aggregate_values(\n", + " labevents_batch,\n", + " window_start_time=start_timestamps,\n", + " )\n", + " means = means.reset_index()\n", + " means = means.pivot(index=\"hadm_id\", columns=\"label\", values=\"valuenum\")\n", + " means = means.add_prefix(\"lab_\")\n", + " means = pd.merge(\n", + " patient_encounters[\n", + " [\n", + " \"hadm_id\",\n", + " \"mortality_outcome\",\n", + " \"age\",\n", + " \"gender\",\n", + " \"admission_location\",\n", + " ]\n", + " ],\n", + " means,\n", + " on=\"hadm_id\",\n", + " )\n", + " means_df = pd.concat([means_df, means])\n", + " if batch_num == 2:\n", + " break\n", + " print(\"Processing batch {}\".format(batch_num + 1))\n", + "\n", + " return means_df\n", + "\n", + "\n", + "def run_query():\n", + " \"\"\"Run query.\"\"\"\n", + " cohort = get_encounters()\n", + " cohort = compute_mortality_outcome(cohort)\n", + "\n", + " return get_labevents(cohort)\n", + "\n", + "\n", + "cohort = run_query()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data Inspection and Preprocessing\n", + "\n", + "### Drop NaNs based on the `NAN_THRESHOLD`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "null_counts = cohort.isnull().sum()[cohort.isnull().sum() > 0]\n", + "fig = go.Figure(data=[go.Bar(x=null_counts.index, y=null_counts.values)])\n", + "\n", + "fig.update_layout(\n", + " title=\"Number of Null Values per Column\",\n", + " xaxis_title=\"Columns\",\n", + " yaxis_title=\"Number of Null Values\",\n", + " height=600,\n", + ")\n", + "\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Add the figure to the report**\n", + "\n", + "We can use the log_plotly_figure method to add the figure to a section of the report. One can specify whether the figure should be interactive or not by setting the `interactive` parameter to `True` or `False` respectively. The default value is `True`. This\n", + "also affects the final size of the report. If the figure is interactive, the size of the report will be larger than if the figure is not interactive. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report.log_plotly_figure(\n", + " fig=fig,\n", + " caption=\"Number of Null Values per Column\",\n", + " section_name=\"datasets\",\n", + " interactive=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "thresh_nan = int(NAN_THRESHOLD * len(cohort))\n", + "cohort = cohort.dropna(axis=1, thresh=thresh_nan)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Outcome distribution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cohort[\"mortality_outcome\"] = cohort[\"mortality_outcome\"].astype(\"int\")\n", + "fig = px.pie(cohort, names=\"mortality_outcome\")\n", + "fig.update_traces(textinfo=\"percent+label\")\n", + "fig.update_layout(title_text=\"Outcome Distribution\")\n", + "fig.update_traces(\n", + " hovertemplate=\"Outcome: %{label}
Count: \\\n", + "    %{value}<br>
Percent: %{percent}\",\n", + ")\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Add the figure to the report**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report.log_plotly_figure(\n", + " fig=fig,\n", + " caption=\"Outcome Distribution\",\n", + " section_name=\"datasets\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The data is heavily unbalanced.\n", + "class_counts = cohort[\"mortality_outcome\"].value_counts()\n", + "class_ratio = class_counts[0] / class_counts[1]\n", + "print(class_ratio)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Gender distribution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = px.pie(cohort, names=\"gender\")\n", + "fig.update_layout(\n", + " title=\"Gender Distribution\",\n", + ")\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Add the figure to the report**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report.log_plotly_figure(\n", + " fig=fig,\n", + " caption=\"Gender Distribution\",\n", + " section_name=\"datasets\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Age distribution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = px.histogram(cohort, x=\"age\")\n", + "fig.update_layout(\n", + " title=\"Age Distribution\",\n", + " xaxis_title=\"Age\",\n", + " yaxis_title=\"Count\",\n", + " bargap=0.2,\n", + ")\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Add the figure to the report**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report.log_plotly_figure(\n", + " fig=fig,\n", + " caption=\"Age Distribution\",\n", + " section_name=\"datasets\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Identifying feature types\n", + "\n", + "Cyclops `TabularFeatures` class helps to identify feature types, an essential step before preprocessing the data. Understanding feature types (numerical/categorical/binary) allows us to apply appropriate preprocessing steps for each type." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "features_list = set(cohort.columns.tolist()) - {\"hadm_id\", \"mortality_outcome\"}\n", + "features_list = sorted(features_list)\n", + "tab_features = TabularFeatures(\n", + " data=cohort.reset_index(),\n", + " features=features_list,\n", + " by=\"hadm_id\",\n", + " targets=\"mortality_outcome\",\n", + ")\n", + "print(tab_features.types)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Creating data preprocessors\n", + "\n", + "We create a data preprocessor using sklearn's ColumnTransformer. This helps in applying different preprocessing steps to different columns in the dataframe. For instance, binary features might be processed differently from numeric features." 
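The task applies this preprocessor internally via `transforms=preprocessor`, but it can be handy to sanity-check it standalone. A sketch, runnable once the `ColumnTransformer` assembled in the next few cells is defined:

```python
# Sketch: fit the preprocessor on the feature columns only and inspect the
# transformed shape (scaled numeric columns + one-hot encoded categories).
transformed = preprocessor.fit_transform(cohort[features_list])
print(transformed.shape)
```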
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "numeric_transformer = Pipeline(\n", + " steps=[(\"imputer\", SimpleImputer(strategy=\"mean\")), (\"scaler\", MinMaxScaler())],\n", + ")\n", + "binary_transformer = Pipeline(\n", + " steps=[(\"imputer\", SimpleImputer(strategy=\"most_frequent\"))],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "numeric_features = sorted((tab_features.features_by_type(\"numeric\")))\n", + "numeric_indices = [\n", + " cohort[features_list].columns.get_loc(column) for column in numeric_features\n", + "]\n", + "print(numeric_features)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "binary_features = sorted(tab_features.features_by_type(\"binary\"))\n", + "ordinal_features = sorted(tab_features.features_by_type(\"ordinal\"))\n", + "binary_features.remove(\"mortality_outcome\")\n", + "binary_indices = [\n", + " cohort[features_list].columns.get_loc(column) for column in binary_features\n", + "]\n", + "ordinal_indices = [\n", + " cohort[features_list].columns.get_loc(column) for column in ordinal_features\n", + "]\n", + "print(binary_features, ordinal_features)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "preprocessor = ColumnTransformer(\n", + " transformers=[\n", + " (\"num\", numeric_transformer, numeric_indices),\n", + " (\n", + " \"onehot\",\n", + " OneHotEncoder(handle_unknown=\"ignore\"),\n", + " binary_indices + ordinal_indices,\n", + " ),\n", + " ],\n", + " remainder=\"passthrough\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating Hugging Face Dataset\n", + "\n", + "We convert our processed Pandas dataframe into a Hugging Face dataset, a powerful and easy-to-use data format which is also compatible with CyclOps models and evaluator modules. The dataset is then split to train and test sets." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cohort = cohort.drop(columns=[\"hadm_id\"])\n", + "dataset = Dataset.from_pandas(cohort)\n", + "dataset.cleanup_cache_files()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = dataset.cast_column(\"mortality_outcome\", ClassLabel(num_classes=2))\n", + "dataset = dataset.train_test_split(\n", + " train_size=TRAIN_SIZE,\n", + " stratify_by_column=\"mortality_outcome\",\n", + " seed=RANDOM_SEED,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model Creation\n", + "\n", + "CyclOps model registry allows for straightforward creation and selection of models. This registry maintains a list of pre-configured models, which can be instantiated with a single line of code. Here we use a XGBoost classifier to fit a logisitic regression model. The model configurations can be passed to `create_model` based on the parameters for XGBClassifer." 
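For example, XGBoost hyperparameters could be forwarded at creation time; a sketch with arbitrary values (only `random_state` is set in the cell that follows):

```python
# Sketch: keyword arguments are passed through to the underlying XGBClassifier.
model = create_model(
    "xgb_classifier",
    random_state=123,
    n_estimators=250,
    learning_rate=0.1,
    max_depth=5,
)
```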
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_name = \"xgb_classifier\"\n", + "model = create_model(model_name, random_state=123)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Task Creation\n", + "\n", + "We use Cyclops tasks to define our model's task (in this case, BinaryTabularClassificationTask), train the model, make predictions, and evaluate performance. Cyclops task classes encapsulate the entire ML pipeline into a single, cohesive structure, making the process smooth and easy to manage." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mortality_task = BinaryTabularClassificationTask(\n", + " {model_name: model},\n", + " task_features=features_list,\n", + " task_target=\"mortality_outcome\",\n", + ")\n", + "mortality_task.list_models()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training\n", + "\n", + "If `best_model_params` is passed to the `train` method, the best model will be selected after the hyperparameter search. The parameters in `best_model_params` indicate the values to create the parameters grid.\n", + "\n", + "Note that the data preprocessor needs to be passed to the tasks methods if the Hugging Face dataset is not already preprocessed. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "best_model_params = {\n", + " \"n_estimators\": [100, 250, 500],\n", + " \"learning_rate\": [0.1, 0.01],\n", + " \"max_depth\": [2, 5],\n", + " \"reg_lambda\": [0, 1, 10],\n", + " \"colsample_bytree\": [0.7, 0.8, 1],\n", + " \"gamma\": [0, 1, 2, 10],\n", + " \"method\": \"random\",\n", + "}\n", + "mortality_task.train(\n", + " dataset[\"train\"],\n", + " model_name=model_name,\n", + " transforms=preprocessor,\n", + " best_model_params=best_model_params,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_params = mortality_task.list_models_params()[model_name]\n", + "print(model_params)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Log the model parameters to the report.**\n", + "\n", + "We can add model parameters to the model card using the `log_model_parameters` method." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report.log_model_parameters(params=model_params)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prediction\n", + "\n", + "The prediction output can be either the whole Hugging Face dataset with the prediction columns added to it or the single column containing the predicted values." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "y_pred = mortality_task.predict(\n", + " dataset[\"test\"],\n", + " model_name=model_name,\n", + " transforms=preprocessor,\n", + " proba=False,\n", + " only_predictions=True,\n", + ")\n", + "print(len(y_pred))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluation\n", + "\n", + "Evaluation is done using various evaluation metrics that provide different perspectives on the model's predictive abilities i.e. standard performance metrics and fairness metrics.\n", + "\n", + "The standard performance metrics can be created using the `MetricCollection` object." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "metric_names = [\n", + " \"accuracy\",\n", + " \"precision\",\n", + " \"recall\",\n", + " \"f1_score\",\n", + " \"auroc\",\n", + " \"roc_curve\",\n", + " \"precision_recall_curve\",\n", + "]\n", + "metrics = [create_metric(metric_name, task=\"binary\") for metric_name in metric_names]\n", + "metric_collection = MetricCollection(metrics)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In addition to overall metrics, it might be interesting to see how the model performs on certain subpopulations. We can define these subpopulations using `SliceSpec` objects. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "spec_list = [\n", + " {\n", + " \"age\": {\n", + " \"min_value\": 20,\n", + " \"max_value\": 50,\n", + " \"min_inclusive\": True,\n", + " \"max_inclusive\": False,\n", + " },\n", + " },\n", + " {\n", + " \"age\": {\n", + " \"min_value\": 50,\n", + " \"max_value\": 80,\n", + " \"min_inclusive\": True,\n", + " \"max_inclusive\": False,\n", + " },\n", + " },\n", + " {\"gender\": {\"value\": \"M\"}},\n", + " {\"gender\": {\"value\": \"F\"}},\n", + "]\n", + "slice_spec = SliceSpec(spec_list)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A `MetricCollection` can also be defined for the fairness metrics." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "specificity = create_metric(\n", + " metric_name=\"specificity\",\n", + " task=\"binary\",\n", + ")\n", + "sensitivity = create_metric(\n", + " metric_name=\"sensitivity\",\n", + " task=\"binary\",\n", + ")\n", + "fpr = 1 - specificity\n", + "fnr = 1 - sensitivity\n", + "ber = (fpr + fnr) / 2\n", + "fairness_metric_collection = MetricCollection(\n", + " {\n", + " \"Sensitivity\": sensitivity,\n", + " \"Specificity\": specificity,\n", + " \"BER\": ber,\n", + " },\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The FairnessConfig helps in setting up and evaluating the fairness of the model predictions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fairness_config = FairnessConfig(\n", + " metrics=fairness_metric_collection,\n", + " dataset=None, # dataset is passed from the evaluator\n", + " target_columns=None, # target columns are passed from the evaluator\n", + " groups=[\"gender\", \"age\"],\n", + " group_bins={\"age\": [20, 40]},\n", + " group_base_values={\"age\": 40, \"gender\": \"M\"},\n", + " thresholds=[0.5],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The evaluate methods outputs the evaluation results and the Hugging Face dataset with the predictions added to it." 
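Once the evaluation cell below has run, the returned `results` dictionary is nested by model, then slice, then metric, which is how the plotting cells later in this notebook index it. A sketch:

```python
# Sketch: pull a single metric out of the nested results structure,
# e.g. the overall AUROC for the trained model.
auroc_overall = results[model_name]["overall"]["BinaryAUROC"]
print(auroc_overall)
```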
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "results, dataset_with_preds = mortality_task.evaluate(\n", + " dataset[\"test\"],\n", + " metric_collection,\n", + " model_names=model_name,\n", + " transforms=preprocessor,\n", + " prediction_column_prefix=\"preds\",\n", + " slice_spec=slice_spec,\n", + " batch_size=64,\n", + " fairness_config=fairness_config,\n", + " override_fairness_metrics=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Log the performance metrics to the report.**\n", + "\n", + "We can add a performance metric to the model card using the `log_performance_metric` method, which expects a dictionary where the keys are in the following format: `slice_name/metric_name`. For instance, `overall/accuracy`. \n", + "\n", + "We first need to process the evaluation results to get the metrics in the right format." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "results_flat = flatten_results_dict(\n", + " results=results,\n", + " remove_metrics=[\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\"],\n", + " model_name=model_name,\n", + ")\n", + "print(results_flat)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for name, metric in results_flat.items():\n", + " split, name = name.split(\"/\") # noqa: PLW2901\n", + " descriptions = {\n", + " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n", + " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n", + " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n", + " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n", + " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n", + " }\n", + " report.log_quantitative_analysis(\n", + " \"performance\",\n", + " name=name,\n", + " value=metric,\n", + " description=descriptions[name],\n", + " metric_slice=split,\n", + " pass_fail_thresholds=0.7,\n", + " pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also use the `ClassificationPlotter` to plot the performance metrics and the add the figure to the model card using the `log_plotly_figure` method." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter = ClassificationPlotter(task_type=\"binary\", class_names=[\"0\", \"1\"])\n", + "plotter.set_template(\"plotly_white\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# extracting the ROC curves and AUROC results for all the slices\n", + "roc_curves = {\n", + " slice_name: slice_results[\"BinaryROCCurve\"]\n", + " for slice_name, slice_results in results[model_name].items()\n", + "}\n", + "aurocs = {\n", + " slice_name: slice_results[\"BinaryAUROC\"]\n", + " for slice_name, slice_results in results[model_name].items()\n", + "}\n", + "roc_curves.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# plotting the ROC curves for all the slices\n", + "roc_plot = plotter.roc_curve_comparison(roc_curves, aurocs=aurocs)\n", + "report.log_plotly_figure(\n", + " fig=roc_plot,\n", + " caption=\"ROC Curve for Female Patients\",\n", + " section_name=\"quantitative analysis\",\n", + ")\n", + "roc_plot.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Extracting the overall classification metric values.\n", + "overall_performance = {\n", + " metric_name: metric_value\n", + " for metric_name, metric_value in results[model_name][\"overall\"].items()\n", + " if metric_name not in [\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\"]\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Plotting the overall classification metric values.\n", + "overall_performance_plot = plotter.metrics_value(\n", + " overall_performance,\n", + " title=\"Overall Performance\",\n", + ")\n", + "report.log_plotly_figure(\n", + " fig=overall_performance_plot,\n", + " caption=\"Overall Performance\",\n", + " section_name=\"quantitative analysis\",\n", + ")\n", + "overall_performance_plot.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Extracting the metric values for all the slices.\n", + "slice_metrics = {\n", + " slice_name: {\n", + " metric_name: metric_value\n", + " for metric_name, metric_value in slice_results.items()\n", + " if metric_name not in [\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\"]\n", + " }\n", + " for slice_name, slice_results in results[model_name].items()\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Plotting the metric values for all the slices.\n", + "slice_metrics_plot = plotter.metrics_comparison_bar(slice_metrics)\n", + "report.log_plotly_figure(\n", + " fig=slice_metrics_plot,\n", + " caption=\"Slice Metric Comparison\",\n", + " section_name=\"quantitative analysis\",\n", + ")\n", + "slice_metrics_plot.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Reformating the fairness metrics\n", + "fairness_results = copy.deepcopy(results[\"fairness\"])\n", + "fairness_metrics = {}\n", + "# remove the group size from the fairness results and add it to the slice name\n", + "for slice_name, slice_results in fairness_results.items():\n", + " group_size = slice_results.pop(\"Group Size\")\n", + " fairness_metrics[f\"{slice_name} (Size={group_size})\"] = slice_results" + ] + }, + 
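A quick way to check the reformatted slice labels before plotting them (the exact slice-name format comes from the fairness evaluator, so the printed values are illustrative):

```python
# Sketch: each key of fairness_metrics now carries the group size appended
# by the loop above.
for key in list(fairness_metrics)[:3]:
    print(key)
```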
{ + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Plotting the fairness metrics\n", + "fairness_plot = plotter.metrics_comparison_scatter(\n", + " fairness_metrics,\n", + " title=\"Fairness Metrics\",\n", + ")\n", + "report.log_plotly_figure(\n", + " fig=fairness_plot,\n", + " caption=\"Fairness Metrics\",\n", + " section_name=\"fairness analysis\",\n", + ")\n", + "fairness_plot.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Report Generation\n", + "\n", + "Before generating the model card, let us document some of the details of the model and some considerations involved in developing and using the model.\n", + "\n", + "\n", + "Let's start with populating the model details section, which includes the following fields by default:\n", + "- description: A high-level description of the model and its usage for a general audience.\n", + "- version: The version of the model.\n", + "- owners: The individuals or organizations that own the model.\n", + "- license: The license under which the model is made available.\n", + "- citation: The citation for the model.\n", + "- references: Links to resources that are relevant to the model.\n", + "- path: The path to where the model is stored.\n", + "- regulatory_requirements: The regulatory requirements that are relevant to the model.\n", + "\n", + "We can add additional fields to the model details section by passing a dictionary to the `log_from_dict` method and specifying the section name as `model_details`. You can also use the `log_descriptor` method to add a new field object with a `description` attribute to any section of the model card." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report.log_from_dict(\n", + " data={\n", + " \"name\": \"Mortality Prediction Model\",\n", + " \"description\": \"The model was trained on the MIMICIV dataset \\\n", + " to predict risk of in-hospital mortality.\",\n", + " },\n", + " section_name=\"model_details\",\n", + ")\n", + "report.log_version(\n", + " version_str=\"0.0.1\",\n", + " date=str(date.today()),\n", + " description=\"Initial Release\",\n", + ")\n", + "report.log_owner(\n", + " name=\"CyclOps Team\",\n", + " contact=\"vectorinstitute.github.io/cyclops/\",\n", + " email=\"cyclops@vectorinstitute.ai\",\n", + ")\n", + "report.log_license(identifier=\"Apache-2.0\")\n", + "report.log_reference(\n", + " link=\"https://xgboost.readthedocs.io/en/stable/python/python_api.html\", # noqa: E501\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, let's populate the considerations section, which includes the following fields by default:\n", + "- users: The intended users of the model.\n", + "- use_cases: The use cases for the model. These could be primary, downstream or out-of-scope use cases.\n", + "- fairness_assessment: A description of the benefits and harms of the model for different groups as well as the steps taken to mitigate the harms.\n", + "- ethical_considerations: The risks associated with using the model and the steps taken to mitigate them. 
This can be populated using the `log_risk` method.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report.log_from_dict(\n", + " data={\n", + " \"users\": [\n", + " {\"description\": \"Hospitals\"},\n", + " {\"description\": \"Clinicians\"},\n", + " ],\n", + " },\n", + " section_name=\"considerations\",\n", + ")\n", + "report.log_user(description=\"ML Engineers\")\n", + "report.log_use_case(\n", + " description=\"Predicting prolonged length of stay\",\n", + " kind=\"primary\",\n", + ")\n", + "report.log_fairness_assessment(\n", + " affected_group=\"sex, age\",\n", + " benefit=\"Improved health outcomes for patients.\",\n", + " harm=\"Biased predictions for patients in certain groups (e.g. older patients) \\\n", + " may lead to worse health outcomes.\",\n", + " mitigation_strategy=\"We will monitor the performance of the model on these groups \\\n", + " and retrain the model if the performance drops below a certain threshold.\",\n", + ")\n", + "report.log_risk(\n", + " risk=\"The model may be used to make decisions that affect the health of patients.\",\n", + " mitigation_strategy=\"The model should be continuously monitored for performance \\\n", + " and retrained if the performance drops below a certain threshold.\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once the model card is populated, you can generate the report using the `export` method. The report is generated in the form of an HTML file. A JSON file containing the model card data will also be generated along with the HTML file. By default, the files will be saved in a folder named `cyclops_reports` in the current working directory. You can change the path by passing a `output_dir` argument when instantiating the `ModelCardReport` class." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report_path = report.export(output_filename=\"mortality_report_periodic.html\")\n", + "shutil.copy(f\"{report_path}\", \".\")\n", + "\n", + "for _ in range(5):\n", + " report._model_card.overview = None\n", + " report._model_card.quantitative_analysis = None\n", + " results_flat = flatten_results_dict(\n", + " results=results,\n", + " remove_metrics=[\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\"],\n", + " model_name=model_name,\n", + " )\n", + "\n", + " for name, metric in results_flat.items():\n", + " split, name = name.split(\"/\") # noqa: PLW2901\n", + " descriptions = {\n", + " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n", + " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. 
Also known as recall or true positive rate.\",\n", + " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n", + " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n", + " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n", + " }\n", + " report.log_quantitative_analysis(\n", + " \"performance\",\n", + " name=name,\n", + " value=np.clip(metric + np.random.normal(0, 0.1), 0, 1),\n", + " description=descriptions[name],\n", + " metric_slice=split,\n", + " pass_fail_thresholds=0.7,\n", + " pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n", + " )\n", + " report_path = report.export(output_filename=\"length_of_stay_report_periodic.html\")\n", + " shutil.copy(f\"{report_path}\", \".\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can view the generated HTML [report](./mortality_report_periodic.html)." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/source/tutorials/synthea/los_prediction.ipynb b/docs/source/tutorials/synthea/los_prediction.ipynb index ac3c763f8..108a75748 100644 --- a/docs/source/tutorials/synthea/los_prediction.ipynb +++ b/docs/source/tutorials/synthea/los_prediction.ipynb @@ -235,6 +235,7 @@ " \"encounter_id\",\n", " {\"description\": (\"count\", \"n_meds\")},\n", " )\n", + "\n", " return cohort.ops(groupby_op).run()\n", "\n", "\n", @@ -249,6 +250,7 @@ " \"encounter_id\",\n", " {\"description\": (\"count\", \"n_procedures\")},\n", " )\n", + "\n", " return cohort.ops(groupby_op).run()\n", "\n", "\n", @@ -269,6 +271,7 @@ " on=\"encounter_id\",\n", " how=\"left\",\n", " )\n", + "\n", " return cohort\n", "\n", "\n", @@ -285,18 +288,6 @@ "### Drop NaNs based on the `NAN_THRESHOLD`" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "c095179e-ea22-4549-9e65-cfce72f441bc", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "querier.list_columns(\"native\", \"encounters\")" - ] - }, { "cell_type": "code", "execution_count": null, @@ -308,14 +299,12 @@ "source": [ "null_counts = cohort.isnull().sum()[cohort.isnull().sum() > 0]\n", "fig = go.Figure(data=[go.Bar(x=null_counts.index, y=null_counts.values)])\n", - "\n", "fig.update_layout(\n", " title=\"Number of Null Values per Column\",\n", " xaxis_title=\"Columns\",\n", " yaxis_title=\"Number of Null Values\",\n", " height=600,\n", ")\n", - "\n", "fig.show()" ] }, @@ -496,11 +485,9 @@ "outputs": [], "source": [ "fig = px.pie(cohort, names=\"gender\")\n", - "\n", "fig.update_layout(\n", " title=\"Gender Distribution\",\n", ")\n", - "\n", "fig.show()" ] }, @@ -554,7 +541,6 @@ " yaxis_title=\"Count\",\n", " bargap=0.2,\n", ")\n", - "\n", "fig.show()" ] }, @@ -775,7 +761,7 @@ "source": [ "## Model Creation\n", "\n", - "CyclOps model registry allows for straightforward creation and selection of models. This registry maintains a list of pre-configured models, which can be instantiated with a single line of code. Here we use a SGD classifier to fit a logisitic regression model. 
The model configurations can be passed to `create_model` based on the sllearn parameters for SGDClassifer." + "CyclOps model registry allows for straightforward creation and selection of models. This registry maintains a list of pre-configured models, which can be instantiated with a single line of code. Here we use a XGBoost classifier to fit a logisitic regression model. The model configurations can be passed to `create_model` based on the parameters for XGBClassifer." ] }, { @@ -798,7 +784,7 @@ "source": [ "## Task Creation\n", "\n", - "We use Cyclops tasks to define our model's task (in this case, MortalityPrediction), train the model, make predictions, and evaluate performance. Cyclops task classes encapsulate the entire ML pipeline into a single, cohesive structure, making the process smooth and easy to manage." + "We use Cyclops tasks to define our model's task (in this case, BinaryTabularClassificationTask), train the model, make predictions, and evaluate performance. Cyclops task classes encapsulate the entire ML pipeline into a single, cohesive structure, making the process smooth and easy to manage." ] }, { @@ -814,18 +800,7 @@ " {model_name: model},\n", " task_features=features_list,\n", " task_target=\"outcome\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "96323637-f5ea-41dd-8899-ce6680d5d58d", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ + ")\n", "los_task.list_models()" ] }, @@ -1028,12 +1003,9 @@ " metric_name=\"sensitivity\",\n", " task=\"binary\",\n", ")\n", - "\n", "fpr = 1 - specificity\n", "fnr = 1 - sensitivity\n", - "\n", "ber = (fpr + fnr) / 2\n", - "\n", "fairness_metric_collection = MetricCollection(\n", " {\n", " \"Sensitivity\": sensitivity,\n", diff --git a/docs/source/tutorials_use_cases.rst b/docs/source/tutorials_use_cases.rst index a28e58af5..4c2def05d 100644 --- a/docs/source/tutorials_use_cases.rst +++ b/docs/source/tutorials_use_cases.rst @@ -16,6 +16,15 @@ variable. tutorials/kaggle/heart_failure_prediction.ipynb +MIMICIV Mortality Prediction +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +This is a binary classification problem where the goal is to predict +risk of in-hospital mortality. The `MIMICIV dataset `_ is an EHR dataset collected from a single hospital site, which includes ICU data. + +.. toctree:: + + tutorials/mimiciv/mortality_prediction.ipynb Synthea Prolonged Length of Stay Prediction ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/poetry.lock b/poetry.lock index d6e051ef7..a370b2892 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4273,13 +4273,13 @@ tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "pa [[package]] name = "platformdirs" -version = "3.11.0" +version = "4.0.0" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." 
optional = false python-versions = ">=3.7" files = [ - {file = "platformdirs-3.11.0-py3-none-any.whl", hash = "sha256:e9d171d00af68be50e9202731309c4e658fd8bc76f55c11c7dd760d023bda68e"}, - {file = "platformdirs-3.11.0.tar.gz", hash = "sha256:cf8ee52a3afdb965072dcc652433e0c7e3e40cf5ea1477cd4b3b1d2eb75495b3"}, + {file = "platformdirs-4.0.0-py3-none-any.whl", hash = "sha256:118c954d7e949b35437270383a3f2531e99dd93cf7ce4dc8340d3356d30f173b"}, + {file = "platformdirs-4.0.0.tar.gz", hash = "sha256:cb633b2bcf10c51af60beb0ab06d2f1d69064b43abf4c185ca6b28865f3f9731"}, ] [package.extras] @@ -4458,6 +4458,8 @@ files = [ {file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"}, {file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"}, {file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"}, + {file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"}, + {file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"}, {file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"}, {file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"}, {file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"}, @@ -4531,13 +4533,13 @@ numpy = ">=1.16.6" [[package]] name = "pyarrow-hotfix" -version = "0.5" +version = "0.6" description = "" optional = false python-versions = ">=3.5" files = [ - {file = "pyarrow_hotfix-0.5-py3-none-any.whl", hash = "sha256:7e20a1195f2e0dd7b50dffb9f90699481acfce3176bfbfb53eded04f34c4f7c6"}, - {file = "pyarrow_hotfix-0.5.tar.gz", hash = "sha256:ba697c743d435545e99bfbd89818b284e4404c19119c0ed63380a92998c4d0b1"}, + {file = "pyarrow_hotfix-0.6-py3-none-any.whl", hash = "sha256:dcc9ae2d220dff0083be6a9aa8e0cdee5182ad358d4931fce825c545e5c89178"}, + {file = "pyarrow_hotfix-0.6.tar.gz", hash = "sha256:79d3e030f7ff890d408a100ac16d6f00b14d44a502d7897cd9fc3e3a534e9945"}, ] [[package]] @@ -4646,13 +4648,13 @@ files = [ [[package]] name = "pygments" -version = "2.17.1" +version = "2.17.2" description = "Pygments is a syntax highlighting package written in Python." 
optional = false python-versions = ">=3.7" files = [ - {file = "pygments-2.17.1-py3-none-any.whl", hash = "sha256:1b37f1b1e1bff2af52ecaf28cc601e2ef7077000b227a0675da25aef85784bc4"}, - {file = "pygments-2.17.1.tar.gz", hash = "sha256:e45a0e74bf9c530f564ca81b8952343be986a29f6afe7f5ad95c5f06b7bdf5e8"}, + {file = "pygments-2.17.2-py3-none-any.whl", hash = "sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c"}, + {file = "pygments-2.17.2.tar.gz", hash = "sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367"}, ] [package.extras] @@ -5752,13 +5754,13 @@ win32 = ["pywin32"] [[package]] name = "setuptools" -version = "69.0.1" +version = "69.0.2" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-69.0.1-py3-none-any.whl", hash = "sha256:6875bbd06382d857b1b90cd07cee6a2df701a164f241095706b5192bc56c5c62"}, - {file = "setuptools-69.0.1.tar.gz", hash = "sha256:f25195d54deb649832182d6455bffba7ac3d8fe71d35185e738d2198a4310044"}, + {file = "setuptools-69.0.2-py3-none-any.whl", hash = "sha256:1e8fdff6797d3865f37397be788a4e3cba233608e9b509382a2777d25ebde7f2"}, + {file = "setuptools-69.0.2.tar.gz", hash = "sha256:735896e78a4742605974de002ac60562d286fa8051a7e2299445e8e8fbb01aa6"}, ] [package.extras] @@ -7111,19 +7113,19 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "virtualenv" -version = "20.24.6" +version = "20.24.7" description = "Virtual Python Environment builder" optional = false python-versions = ">=3.7" files = [ - {file = "virtualenv-20.24.6-py3-none-any.whl", hash = "sha256:520d056652454c5098a00c0f073611ccbea4c79089331f60bf9d7ba247bb7381"}, - {file = "virtualenv-20.24.6.tar.gz", hash = "sha256:02ece4f56fbf939dbbc33c0715159951d6bf14aaf5457b092e4548e1382455af"}, + {file = "virtualenv-20.24.7-py3-none-any.whl", hash = "sha256:a18b3fd0314ca59a2e9f4b556819ed07183b3e9a3702ecfe213f593d44f7b3fd"}, + {file = "virtualenv-20.24.7.tar.gz", hash = "sha256:69050ffb42419c91f6c1284a7b24e0475d793447e35929b488bf6a0aade39353"}, ] [package.dependencies] distlib = ">=0.3.7,<1" filelock = ">=3.12.2,<4" -platformdirs = ">=3.9.1,<4" +platformdirs = ">=3.9.1,<5" [package.extras] docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] @@ -7145,13 +7147,13 @@ colorama = {version = ">=0.4.6", markers = "sys_platform == \"win32\" and python [[package]] name = "wcwidth" -version = "0.2.11" +version = "0.2.12" description = "Measures the displayed width of unicode strings in a terminal" optional = false python-versions = "*" files = [ - {file = "wcwidth-0.2.11-py2.py3-none-any.whl", hash = "sha256:c4b153acf29f1f0d7fb1b00d097cce82b73de7a2016321c8d7ca71bd76dd848b"}, - {file = "wcwidth-0.2.11.tar.gz", hash = "sha256:25eb3ecbec328cdb945f56f2a7cfe784bdf7a73a8197398c7a7c65e7fe93e9ae"}, + {file = "wcwidth-0.2.12-py2.py3-none-any.whl", hash = "sha256:f26ec43d96c8cbfed76a5075dac87680124fa84e0855195a6184da9c187f133c"}, + {file = "wcwidth-0.2.12.tar.gz", hash = "sha256:f01c104efdf57971bcb756f054dd58ddec5204dd15fa31d6503ea57947d97c02"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index 04368e92f..49832b547 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -175,7 +175,6 @@ ignore = [ # Ignore import violations in all `__init__.py` files. 
[tool.ruff.per-file-ignores] "__init__.py" = ["E402", "F401", "F403", "F811"] -"use_cases/data_processors/mimiciv.py" = ["D417", "N806", "N803"] [tool.ruff.pep8-naming] ignore-names = ["X*", "setUp"] diff --git a/use_cases/__init__.py b/use_cases/__init__.py deleted file mode 100644 index 04d7f4a6b..000000000 --- a/use_cases/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Use-case implementations.""" diff --git a/use_cases/common/util.py b/use_cases/common/util.py deleted file mode 100644 index c1e5b01ec..000000000 --- a/use_cases/common/util.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Utility functions shared across use-cases.""" - -import importlib -import types - - -def get_use_case_params(dataset: str, use_case: str) -> types.ModuleType: - """Import parameters specific to each use-case. - - Parameters - ---------- - dataset: str - Name of the dataset, e.g. mimiciv. - use_case: str - Name of the use-case, e.g. mortality_decompensation. - - Returns - ------- - types.ModuleType - Imported constants module with use-case parameters. - - """ - return importlib.import_module( - ".".join(["use_cases", "params", dataset, use_case, "constants_v1"]), - ) diff --git a/use_cases/data_collectors/mimiciv/mortality_decompensation.ipynb b/use_cases/data_collectors/mimiciv/mortality_decompensation.ipynb deleted file mode 100644 index 2f123c708..000000000 --- a/use_cases/data_collectors/mimiciv/mortality_decompensation.ipynb +++ /dev/null @@ -1,259 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "c19e50b7-6ff0-48cd-a844-00d9fa4c4606", - "metadata": {}, - "source": [ - "# Imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7e7ef0f0-c03d-454e-ae99-2ad78e65a7b4", - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "\n", - "from cyclops.process.clean import normalize_categories, normalize_names\n", - "from cyclops.process.column_names import (\n", - " ENCOUNTER_ID,\n", - " EVENT_CATEGORY,\n", - " EVENT_NAME,\n", - " EVENT_TIMESTAMP,\n", - ")\n", - "from cyclops.query import process as qp\n", - "from cyclops.query.mimiciv import MIMICIVQuerier\n", - "from cyclops.utils.common import add_years_approximate\n", - "from cyclops.utils.file import join, load_dataframe, save_dataframe, yield_dataframes\n", - "from use_cases.params.mimiciv.mortality_decompensation.constants import (\n", - " CLEANED_DIR,\n", - " ENCOUNTERS_FILE,\n", - " OUTCOME_DEATH,\n", - " QUERIED_DIR,\n", - ")\n", - "\n", - "\n", - "mimic = MIMICIVQuerier()" - ] - }, - { - "cell_type": "markdown", - "id": "2da0063b-7881-475f-a5ca-a267c0d9fb0d", - "metadata": {}, - "source": [ - "# Query" - ] - }, - { - "cell_type": "markdown", - "id": "0d3412ae-a3a2-49b1-99d7-f13014cb3d35", - "metadata": { - "tags": [] - }, - "source": [ - "## Patient encounters" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fd7883ba-5319-4035-91fb-0128fb36821d", - "metadata": {}, - "outputs": [], - "source": [ - "encounters_interface = mimic.patient_encounters()\n", - "\n", - "encounters_query = encounters_interface.query\n", - "encounters_query = qp.Drop(\n", - " [\"insurance\", \"language\", \"marital_status\", \"edregtime\", \"edouttime\"],\n", - ")(encounters_query)\n", - "\n", - "encounters_interface = mimic.get_interface(encounters_query)\n", - "encounters = encounters_interface.run()\n", - "encounters.head(5)" - ] - }, - { - "cell_type": "markdown", - "id": "5371359b-249d-45b2-8384-2a4757cf1d70", - "metadata": {}, - "source": [ - "Create death indicator\n", - "\n", - "Hospital expire 
flag:\n", - " - 1 - Death in hospital\n", - " - 0 - Survived past discharge" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "449324b6-09a1-4e14-9864-9635f61eb635", - "metadata": {}, - "outputs": [], - "source": [ - "# Drop encounters ending in death which don't have a death timestamp\n", - "invalid = (encounters[\"hospital_expire_flag\"] == 1) & (encounters[\"deathtime\"].isna())\n", - "encounters = encounters[~invalid]\n", - "\n", - "# (Died in hospital) & (Death timestamp is defined)\n", - "encounters[OUTCOME_DEATH] = encounters[\"hospital_expire_flag\"] == 1\n", - "encounters.head(5)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d2aca418-e3e8-468e-ade0-a39f415240c3", - "metadata": {}, - "outputs": [], - "source": [ - "(encounters[OUTCOME_DEATH] == True).sum() / len(encounters) # noqa: E712" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e97d4889-d8ac-4648-b60c-3cdd18854567", - "metadata": {}, - "outputs": [], - "source": [ - "save_dataframe(encounters, ENCOUNTERS_FILE)" - ] - }, - { - "cell_type": "markdown", - "id": "a176f2de-f264-4253-ab11-da9017a4153c", - "metadata": {}, - "source": [ - "## Events" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "de68c082-17d8-46bc-b9ce-4b19b0a97fcd", - "metadata": {}, - "outputs": [], - "source": [ - "events_interface = mimic.events()\n", - "events_query = events_interface.query\n", - "events_query = qp.Drop([\"warning\", \"itemid\", \"storetime\"])(events_query)\n", - "events_interface = mimic.get_interface(events_query)\n", - "events_interface.save_in_grouped_batches(QUERIED_DIR, ENCOUNTER_ID, int(1e6))" - ] - }, - { - "cell_type": "markdown", - "id": "7fdec0d6-5170-486b-85b8-c0b4294a3358", - "metadata": {}, - "source": [ - "# Clean / Preprocess" - ] - }, - { - "cell_type": "markdown", - "id": "de37e20f-5e00-4fcb-8f68-c4675de2c7b7", - "metadata": {}, - "source": [ - "Can be run entirely separately from the querying." 
- ] - }, - { - "cell_type": "markdown", - "id": "1a2b11e1-c784-4673-83e8-5fb456efab56", - "metadata": {}, - "source": [ - "## Patient encounters" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7104dacf-c4e0-4083-9e81-c83ea701e6d3", - "metadata": {}, - "outputs": [], - "source": [ - "encounters = load_dataframe(ENCOUNTERS_FILE)" - ] - }, - { - "cell_type": "markdown", - "id": "76eb03ba-8d69-4498-ae08-99f82d2a66ea", - "metadata": {}, - "source": [ - "## Events" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aa554f42-4bc1-42dd-91b8-d02b2f2c3f5c", - "metadata": {}, - "outputs": [], - "source": [ - "skip_n = 0\n", - "generator = yield_dataframes(QUERIED_DIR, skip_n=skip_n, log=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1c7c0e9a-b72d-4e3f-9ac2-802eb354d3e7", - "metadata": {}, - "outputs": [], - "source": [ - "for save_count, events in enumerate(generator):\n", - " events = events.drop([\"stay_id\"], axis=1)\n", - "\n", - " # Reverse deidentified dating\n", - " events = pd.merge(\n", - " encounters[[ENCOUNTER_ID, \"anchor_year_difference\"]],\n", - " events,\n", - " on=ENCOUNTER_ID,\n", - " )\n", - " events[EVENT_TIMESTAMP] = add_years_approximate(\n", - " events[EVENT_TIMESTAMP],\n", - " events[\"anchor_year_difference\"],\n", - " )\n", - " events = events.drop(\"anchor_year_difference\", axis=1)\n", - "\n", - " # Preprocessing\n", - " events[EVENT_NAME] = normalize_names(events[EVENT_NAME])\n", - " events[EVENT_CATEGORY] = normalize_categories(events[EVENT_CATEGORY])\n", - "\n", - " # Concatenate event name and category since some names are the same in\n", - " # different categories, e.g., 'flow' for categories 'heartware' and 'ecmo'\n", - " events[EVENT_NAME] = events[EVENT_CATEGORY] + \" - \" + events[EVENT_NAME]\n", - " events.head(5)\n", - "\n", - " save_dataframe(events, join(CLEANED_DIR, \"batch_\" + f\"{save_count + skip_n:04d}\"))\n", - " del events" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "cyclops", - "language": "python", - "name": "cyclops" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/use_cases/data_processors/mimiciv.py b/use_cases/data_processors/mimiciv.py deleted file mode 100644 index ef6904627..000000000 --- a/use_cases/data_processors/mimiciv.py +++ /dev/null @@ -1,1476 +0,0 @@ -"""MIMICIV processor.""" - -import logging -from os import path -from typing import Callable, Generator, List, Optional, Tuple - -import numpy as np -import pandas as pd - -from cyclops.process.aggregate import ( - Aggregator, - tabular_as_aggregated, - timestamp_ffill_agg, -) -from cyclops.process.constants import FEATURES, NUMERIC, ORDINAL, TARGETS -from cyclops.process.feature.feature import TabularFeatures, TemporalFeatures -from cyclops.process.feature.vectorized import ( - Vectorized, - intersect_vectorized, - split_vectorized, -) -from cyclops.query import mimiciv as mimic -from cyclops.utils.file import ( - join, - load_dataframe, - load_pickle, - save_dataframe, - save_pickle, - yield_dataframes, - yield_pickled_files, -) -from cyclops.utils.log import setup_logging -from use_cases.util import get_top_events, get_use_case_params, valid_events - - -LOGGER = logging.getLogger(__name__) 
-setup_logging(print_level="INFO", logger=LOGGER) - - -class MIMICIVProcessor: - """MIMICIV processor.""" - - def __init__(self, use_case: str, data_type: str) -> None: - """Init processor. - - Parameters - ---------- - use_case : str - The use-case to process the data for. - - data_type : str - Type of data (tabular, temporal, or combined). - - """ - self.params = get_use_case_params("mimiciv", use_case) - self.data_type = data_type - self._setup_paths() - self._setup_params() - self.aggregator = self._init_aggregator() - - ################### - # Init methods - ################### - - def _setup_paths(self) -> None: - """Set up paths and dirs.""" - self.data_dir = self.params.DATA_DIR - - self.queried_dir = self.params.QUERIED_DIR - self.cleaned_dir = self.params.CLEANED_DIR - self.aggregated_dir = self.params.AGGREGATED_DIR - self.vectorized_dir = self.params.VECTORIZED_DIR - self.final_vectorized = self.params.FINAL_VECTORIZED - - self.tabular_file = self.params.TABULAR_FILE - self.tab_features_file = self.params.TAB_FEATURES_FILE - self.tab_slice_file = self.params.TAB_SLICE_FILE - self.tab_aggregated_file = self.params.TAB_AGGREGATED_FILE - self.tab_vectorized_file = self.params.TAB_VECTORIZED_FILE - self.temp_vectorized_file = self.params.TEMP_VECTORIZED_FILE - self.comb_vectorized_file = self.params.COMB_VECTORIZED_FILE - - self.aligned_path = self.params.ALIGNED_PATH - self.unaligned_path = self.params.UNALIGNED_PATH - - def _setup_params(self) -> None: - """Set up the data processing parameters.""" - self.common_feature = ( - self.params.COMMON_FEATURE if self.params.COMMON_FEATURE else None - ) - self.skip_n = self.params.SKIP_N if self.params.SKIP_N else 0 - self.split_fractions = ( - self.params.SPLIT_FRACTIONS if self.params.SPLIT_FRACTIONS else None - ) - self.tab_feature_params = ( - self.params.TABULAR_FEATURES if self.params.TABULAR_FEATURES else None - ) - self.tab_norm_params = ( - self.params.TABULAR_NORM if self.params.TABULAR_NORM else None - ) - self.tab_slice_params = ( - self.params.TABULAR_SLICE if self.params.TABULAR_SLICE else None - ) - self.tab_agg_params = ( - self.params.TABULAR_AGG if self.params.TABULAR_AGG else None - ) - - self.temp_params = ( - self.params.TEMPORAL_PARAMS if self.params.TEMPORAL_PARAMS else None - ) - self.temp_norm_params = ( - self.params.TEMPORAL_NORM if self.params.TEMPORAL_NORM else None - ) - self.temp_feature_params = ( - self.params.TEMPORAL_FEATURES if self.params.TEMPORAL_FEATURES else None - ) - self.temp_target_params = ( - self.params.TEMPORAL_TARGETS if self.params.TEMPORAL_TARGETS else None - ) - self.timestamp_params = ( - self.params.TIMESTAMPS if self.params.TIMESTAMPS else None - ) - self.timestep_params = self.params.TIMESTEPS if self.params.TIMESTEPS else None - self.temp_slice_params = ( - self.params.TEMPORAL_SLICE if self.params.TEMPORAL_SLICE else None - ) - self.temp_agg_params = ( - self.params.TEMPORAL_AGG if self.params.TEMPORAL_AGG else None - ) - self.temp_impute_params = ( - self.params.TEMPORAL_IMPUTE if self.params.TEMPORAL_IMPUTE else None - ) - - def _init_aggregator(self) -> Aggregator: - """Initialize the aggregator for temporal and combined processing. - - Returns - ------- - Aggregator - The aggregator object. 
- - """ - return Aggregator( - aggfuncs=self.temp_agg_params["aggfuncs"], - timestamp_col=self.temp_agg_params["timestamp_col"], - time_by=self.temp_agg_params["time_by"], - agg_by=self.temp_agg_params["agg_by"], - timestep_size=self.temp_agg_params["timestep_size"], - window_duration=self.temp_agg_params["window_duration"], - ) - - ################### - # Common methods - ################### - - def _load_batches(self, data_dir: str) -> Generator[pd.DataFrame, None, None]: - """Load the data files saved as dataframes. - - Parameters - ---------- - data_dir : str - The directory path of files. - - Yields - ------ - pandas.DataFrame - A DataFrame. - - """ - return yield_dataframes(data_dir, skip_n=self.skip_n, log=False) - - def _normalize(self, vectorized: Vectorized) -> Vectorized: - """Fit normalizer and normalize. - - Parameters - ---------- - vectorized : Vectorized - Vectorized data. - - Returns - ------- - Vectorized - Vectorized data after normalization. - - """ - vectorized.fit_normalizer() - vectorized.normalize() - return vectorized - - ################### - # Tabular methods - ################### - - def _load_cohort(self) -> pd.DataFrame: - """Load the tabular cohort. - - Returns - ------- - pd.DataFrame - Tabular data. - - """ - return load_dataframe(self.tabular_file) - - def _get_tabular_features(self, tab_data: pd.DataFrame) -> TabularFeatures: - """Get the tabular features as an object. - - Parameters - ---------- - tab_data : pd.DataFrame - Tabular data. - - Returns - ------- - TabularFeatures - The tabular features object. - - """ - tab_features = TabularFeatures( - data=tab_data, - features=self.tab_feature_params["features"], - by=self.tab_feature_params["primary_feature"], - force_types=self.tab_feature_params["features_types"], - ) - save_pickle(tab_features, self.tab_features_file) - return tab_features - - def _slice_tabular(self, tab_features: TabularFeatures) -> np.ndarray: - """Slice the tabular data. - - Parameters - ---------- - tab_features : TabularFeatures - The tabular features object. - - Returns - ------- - np.ndarray - Array of the values of the "by" column, in the sliced dataset. - - """ - sliced_tab = tab_features.slice( - slice_map=self.tab_slice_params["slice_map"], - slice_query=self.tab_slice_params["slice_query"], - replace=self.tab_slice_params["replace"], - ) - if self.tab_slice_params["replace"]: - save_pickle(tab_features, self.tab_slice_file) - return sliced_tab - - def _get_tab_ordinal(self, tab_features: TabularFeatures) -> List[str]: - """Get the names of ordinal features in the tabular data. - - Parameters - ---------- - tab_features : TabularFeatures - The tabular features object. - - Returns - ------- - List[str] - List of ordinal features. - - """ - return tab_features.features_by_type(ORDINAL) - - def _get_tab_numeric(self, tab_features: TabularFeatures) -> List[str]: - """Get the names of numeric features in the tabular data. - - Parameters - ---------- - tab_features : TabularFeatures - The tabular features object. - - Returns - ------- - List[str] - List of numeric features. - - """ - return tab_features.features_by_type(NUMERIC) - - def _vectorize_tabular( - self, - tab_features: TabularFeatures, - normalize: bool, - ) -> Vectorized: - """Vectorize the tabular data. - - Parameters - ---------- - tab_features : TabularFeatures - The tabular features object. - normalize : bool - Whether to normalize numeric features. - - Returns - ------- - Vectorized - Vectorized tabular data. 
- - """ - tab_vectorized = tab_features.vectorize( - to_binary_indicators=self._get_tab_ordinal(tab_features), - ) - - if normalize: - normalizer_map = { - feat: self.tab_norm_params["method"] - for feat in self._get_tab_numeric(tab_features) - } - tab_vectorized.add_normalizer( - FEATURES, - normalizer_map=normalizer_map, - ) - - save_pickle(tab_vectorized, self.tab_vectorized_file) - return tab_vectorized - - def _aggregate_tabular( - self, - tab_features: TabularFeatures, - temp_vectorized: Vectorized, - ) -> pd.DataFrame: - """Aggregate the tabular data to pose as timeseries. - - Parameters - ---------- - tab_features : TabularFeatures - The tabular features object. - temp_vectorized : Vectorized - Vectorized temporal data. - - Returns - ------- - pd.DataFrame - Aggregated tabular data. - - """ - tab = tab_features.get_data( - to_binary_indicators=self._get_tab_ordinal(tab_features), - ).reset_index() - - tab = tab[ - np.in1d( - tab[self.common_feature].values, - temp_vectorized.get_index(self.common_feature), - ) - ] - - tab_aggregated = tabular_as_aggregated( - tab=tab, - index=self.tab_agg_params["index"], - var_name=self.tab_agg_params["var_name"], - value_name=self.tab_agg_params["value_name"], - strategy=self.tab_agg_params["strategy"], - num_timesteps=self.temp_agg_params["window_duration"] - // self.temp_agg_params["timestep_size"], - ) - save_dataframe(tab_aggregated, self.tab_aggregated_file) - return tab_aggregated - - def _vectorize_agg_tabular(self, tab_aggregated: pd.DataFrame) -> Vectorized: - """Vectorize the aggregated tabular data. - - Parameters - ---------- - tab_aggregated : pd.DataFrame - Aggregated tabular data. - - Returns - ------- - Vectorized - Vectorized tabular data. - - """ - return self.aggregator.vectorize(tab_aggregated) - - def _split_tabular(self, tab_vectorized: Vectorized) -> Tuple: - """Split tabular data to train, validation, and test sets. - - Parameters - ---------- - tab_vectorized : Vectorized - Vectorized tabular data. - - Returns - ------- - Tuple - A tuple of datasets of splits. All splits are Vectorized objects. - - """ - fractions = self.split_fractions.copy() - tab_train, tab_val, tab_test = split_vectorized( - vecs=[tab_vectorized], - fractions=fractions, - axes=self.common_feature, - )[0] - return tab_train, tab_val, tab_test - - def _get_tab_train(self, tab_train: Vectorized, normalize: bool) -> Tuple: - """Get the tabular train features (normalized) and the targets. - - Parameters - ---------- - tab_train : Vectorized - Vectorized tabular data. - normalize : bool - Whether to normalize the numeric features. - - Returns - ------- - Tuple - Tuple of train features and targets. - - """ - tab_train_X, tab_train_y = tab_train.split_out( - FEATURES, - self.tab_feature_params[TARGETS], - ) - if normalize: - tab_train_X = self._normalize(tab_train_X) - - return tab_train_X, tab_train_y - - def _get_tab_val(self, tab_val: Vectorized, normalize: bool) -> Tuple: - """Get the tabular validation features (normalized) and the targets. - - Parameters - ---------- - tab_val : Vectorized - Vectorized tabular data. - normalize : bool - Whether to normalize the numeric features. - - Returns - ------- - Tuple - Tuple of validation features and targets. 
- - """ - tab_val_X, tab_val_y = tab_val.split_out( - FEATURES, - self.tab_feature_params[TARGETS], - ) - if normalize: - tab_val_X = self._normalize(tab_val_X) - return tab_val_X, tab_val_y - - def _get_tab_test(self, tab_test: Vectorized, normalize: bool): - """Get the tabular test features (normalized) and the targets. - - Parameters - ---------- - tab_test : Vectorized - Vectorized tabular data. - normalize : bool - Whether to normalize the numeric features. - - Returns - ------- - Tuple - Tuple of test features and targets. - - """ - tab_test_X, tab_test_y = tab_test.split_out( - FEATURES, - self.tab_feature_params[TARGETS], - ) - if normalize: - tab_test_X = self._normalize(tab_test_X) - return tab_test_X, tab_test_y - - def _save_tabular( - self, - tab_train_X: Vectorized, - tab_train_y: Vectorized, - tab_val_X: Vectorized, - tab_val_y: Vectorized, - tab_test_X: Vectorized, - tab_test_y: Vectorized, - aligned: bool, - ) -> None: - """Save the tabular features and targets for all data splits. - - Parameters - ---------- - tab_train_X : Vectorized - Vectorized tabular features from the train set. - tab_train_y : Vectorized - Vectorized tabular targets from the train set. - tab_val_X : Vectorized - Vectorized tabular features from the validation set. - tab_val_y : Vectorized - Vectorized tabular targets from the validation set. - tab_test_X : Vectorized - Vectorized tabular features from the test set. - tab_test_y : Vectorized - Vectorized tabular targets from the test set. - aligned : bool - Whether data is aligned with the temporal data. - - """ - vectorized = [ - (tab_train_X, "tab_train_X"), - (tab_train_y, "tab_train_y"), - (tab_val_X, "tab_val_X"), - (tab_val_y, "tab_val_y"), - (tab_test_X, "tab_test_X"), - (tab_test_y, "tab_test_y"), - ] - for vec, name in vectorized: - if aligned: - save_pickle(vec, self.aligned_path + name) - else: - save_pickle(vec, self.unaligned_path + name) - - #################### - # Temporal methods - #################### - - def _get_temporal_features(self, data: pd.DataFrame) -> TemporalFeatures: - """Get the temporal features as an object. - - Parameters - ---------- - data : pd.DataFrame - Temporal data. - - Returns - ------- - TemporalFeatures - The temporal features object. - - """ - return TemporalFeatures( - data, - features=self.temp_feature_params["features"], - by=self.temp_feature_params["groupby"], - timestamp_col=self.temp_feature_params["timestamp_col"], - aggregator=self.aggregator, - ) - - def _get_timestamps(self, data: Optional[pd.DataFrame] = None) -> pd.DataFrame: - """Get relevant timestamps either from tabular data or the input dataframe. - - Parameters - ---------- - data : Optional[pd.DataFrame], optional - The dataframe to extract the timestamp from, by default None. - - Returns - ------- - pd.DataFrame - Timestamps data. - - """ - if not data: - data = load_dataframe(self.tabular_file) - return data[self.timestamp_params["columns"]] - - def _get_start_timestamps(self) -> pd.DataFrame: - """Get relevant start timestamps e.g. hospital admission time. - - Returns - ------- - pd.DataFrame - The start timestamps. - - """ - timestamps = self._get_timestamps() - return ( - timestamps[self.timestamp_params["start_columns"]] - .set_index(self.timestamp_params["start_index"]) - .rename(self.timestamp_params["rename"], axis=1) - ) - - def _aggregate_temporal_batches( - self, - generator: Generator[pd.DataFrame, None, None], - filter_fn: Optional[Callable] = None, - ) -> None: - """Aggregate the temporal data saved in batches. 
- - Parameters - ---------- - generator : Generator[pd.DataFrame, None, None] - Generator to yield the saved data files. - filter_fn : Optional[Callable], optional - Filter the data records before aggregating, by default None. - - """ - start_timestamps = self._get_start_timestamps() - for save_count, batch in enumerate(generator): - if filter_fn: - batch = filter_fn(batch) - batch = batch.reset_index(drop=True) - temp_features = self._get_temporal_features(batch) - aggregated = temp_features.aggregate(window_start_time=start_timestamps) - save_dataframe( - aggregated, - join(self.aggregated_dir, "batch_" + f"{save_count + self.skip_n:04d}"), - ) - del batch - - def _vectorize_temporal_batches(self, generator: Generator) -> None: - """Vectorize the temporal features saved in batches. - - Parameters - ---------- - generator : Generator - Generator to yield the saved data files. - - """ - for save_count, batch in enumerate(generator): - vec = self.aggregator.vectorize(batch) - save_pickle( - vec, - join(self.vectorized_dir, "batch_" + f"{save_count + self.skip_n:04d}"), - ) - - def _vectorize_temporal_features( - self, - generator: Generator[pd.DataFrame, None, None], - ) -> Vectorized: - """Vectorize temporal features (no targets included). - - Parameters - ---------- - generator: Generator[pd.DataFrame, None, None] - Generator to yield the saved data files. - - Returns - ------- - Vectorized - Vectorized temporal data. - - """ - vecs = list(generator) - join_axis = vecs[0].get_axis(self.common_feature) - res = np.concatenate([vec.data for vec in vecs], axis=join_axis) - indexes = vecs[0].indexes - indexes[join_axis] = np.concatenate([vec.indexes[join_axis] for vec in vecs]) - temp_vectorized = Vectorized(res, indexes, vecs[0].axis_names) - del res - return temp_vectorized - - def _compute_timestep( - self, - timestamps: pd.DataFrame, - timestamp_col: str, - ) -> pd.DataFrame: - """Compute timestep for a specific timestamp feature. - - Parameters - ---------- - timestamps : pd.DataFrame - The timestamps data. - timestamp_col : str - The timestamp for which the timestep is to be computed. - - Returns - ------- - pd.DataFrame - Timestamps with the new timestep. - - """ - timestep_size = self.temp_params["timestep_size"] - new_timestamp = f"{timestamp_col}_{self.timestep_params['new_timestamp']}" - timestamps[new_timestamp] = ( - timestamps[timestamp_col] - timestamps[self.timestep_params["anchor"]] - ) - - timestep_col = f"{timestamp_col}_timestep" - timestamps[timestep_col] = ( - timestamps[new_timestamp] / pd.Timedelta(f"{timestep_size} hour") - ).apply(np.floor) - return timestamps - - def _create_target( - self, - temp_vectorized: Vectorized, - timestamps: pd.DataFrame, - ) -> np.ndarray: - """Create targets for temporal data based on the window duration. - - Parameters - ---------- - temp_vectorized : Vectorized - Vectorized temporal data. - timestamps : pd.DataFrame - The timestamps data. - - Returns - ------- - np.ndarray - Array of the target values. 
- - """ - index_order = pd.Series(temp_vectorized.get_index(self.common_feature)) - index_order = index_order.rename(self.common_feature).to_frame() - target_timestamp = self.temp_target_params["target_timestamp"] - target_timestep = "target_timestep" - ref_timestamp = self.temp_target_params["ref_timestamp"] - ref_timestep = f"{ref_timestamp}_timestep" - - timestamps["target"] = timestamps[target_timestamp] - pd.DateOffset( - hours=self.temp_params["predict_offset"], - ) - - timestamps = self._compute_timestep(timestamps, "target") - timestamps = self._compute_timestep( - timestamps, - ref_timestamp, - ) - - timesteps = timestamps[ - [ - self.common_feature, - target_timestep, - ref_timestep, - ] - ] - - aligned_timestamps = pd.merge( - index_order, - timesteps, - on=self.common_feature, - how="left", - ) - - num_timesteps = int( - self.temp_params["window_duration"] / self.temp_params["timestep_size"], - ) - - arr1 = timestamp_ffill_agg( - aligned_timestamps[target_timestep], - num_timesteps, - fill_nan=2, - ) - - arr2 = timestamp_ffill_agg( - aligned_timestamps[ref_timestep], - num_timesteps, - val=-1, - fill_nan=2, - ) - - targets = np.minimum(arr1, arr2) - targets[targets == 2] = 0 - return np.expand_dims(np.expand_dims(targets, 0), 2) - - def _vectorize_temporal( - self, - temp_vectorized: Vectorized, - targets: np.ndarray, - normalize: bool, - ) -> Vectorized: - """Vectorize the tabular data. - - Parameters - ---------- - temp_vectorized : Vectorized - Vectorized temporal features. - targets: np.ndarray - Array of temporal targets. - normalize : bool - Whether to normalize the data. - - Returns - ------- - Vectorized - Vectorized temporal data containing features and targets. - - """ - temp_vectorized = temp_vectorized.concat_over_axis( - self.temp_feature_params["primary_feature"], - targets, - self.temp_feature_params[TARGETS], - ) - - if normalize: - temp_vectorized.add_normalizer( - self.temp_feature_params["primary_feature"], - normalization_method=self.temp_norm_params["method"], - ) - - save_pickle(temp_vectorized, self.temp_vectorized_file) - return temp_vectorized - - def _split_temporal(self, temp_vectorized: Vectorized) -> Tuple: - """Split the temporal data to train, validation, and test sets. - - Parameters - ---------- - temp_vectorized : Vectorized - Vectorized temporal data. - - Returns - ------- - Tuple - A tuple of datasets of splits. All splits are Vectorized objects. - - """ - fractions = self.split_fractions.copy() - temp_train, temp_val, temp_test = split_vectorized( - vecs=[temp_vectorized], - fractions=fractions, - axes=self.common_feature, - )[0] - return temp_train, temp_val, temp_test - - def _get_temp_train( - self, - temp_train: Vectorized, - normalize: bool, - impute: Optional[bool] = True, - ) -> Tuple: - """Get the temporal train features (normalized) and the targets. - - Parameters - ---------- - temp_train : Vectorized - Vectorized temporal data. - normalize : bool - Whether to normalize the data. - impute : bool - Whether to impute values. - - Returns - ------- - Tuple - Tuple of train features and targets. 
- - """ - temp_train_X, temp_train_y = temp_train.split_out( - self.temp_feature_params["primary_feature"], - self.temp_feature_params[TARGETS], - ) - if impute: - temp_train_X.impute( - self.temp_impute_params["axis"], - self.temp_feature_params["primary_feature"], - self.temp_impute_params["func"], - ) - - if normalize: - temp_train_X = self._normalize(temp_train_X) - - return temp_train_X, temp_train_y - - def _get_temp_val( - self, - temp_val, - normalize: bool, - impute: Optional[bool] = True, - ) -> Tuple: - """Get the temporal validation features (normalized) and the targets. - - Parameters - ---------- - temp_val : Vectorized - Vectorized temporal data. - normalize : bool - Whether to normalize the data. - impute : bool - Whether to impute values. - - Returns - ------- - Tuple - Tuple of validation features and targets. - - """ - temp_val_X, temp_val_y = temp_val.split_out( - self.temp_feature_params["primary_feature"], - self.temp_feature_params[TARGETS], - ) - if impute: - temp_val_X.impute( - self.temp_impute_params["axis"], - self.temp_feature_params["primary_feature"], - self.temp_impute_params["func"], - ) - if normalize: - temp_val_X = self._normalize(temp_val_X) - return temp_val_X, temp_val_y - - def _get_temp_test( - self, - temp_test, - normalize: bool, - impute: Optional[bool] = True, - ) -> Tuple: - """Get the temporal test features (normalized) and the targets. - - Parameters - ---------- - temp_test : Vectorized - Vectorized temporal data. - normalize : bool - Whether to normalize the data. - impute : bool - Whether to impute values. - - Returns - ------- - Tuple - Tuple of test features and targets. - - """ - temp_test_X, temp_test_y = temp_test.split_out( - self.temp_feature_params["primary_feature"], - self.temp_feature_params[TARGETS], - ) - if impute: - temp_test_X.impute( - self.temp_impute_params["axis"], - self.temp_feature_params["primary_feature"], - self.temp_impute_params["func"], - ) - if normalize: - temp_test_X = self._normalize(temp_test_X) - return temp_test_X, temp_test_y - - def _save_temporal( - self, - temp_train_X, - temp_train_y, - temp_val_X, - temp_val_y, - temp_test_X, - temp_test_y, - aligned, - ): - """Save the temporal features and targets for all data splits. - - Parameters - ---------- - temp_train_X : Vectorized - Vectorized temporal features from the train set. - temp_train_y : Vectorized - Vectorized temporal targets from the train set. - temp_val_X : Vectorized - Vectorized temporal features from the validation set. - temp_val_y : Vectorized - Vectorized temporal targets from the validation set. - temp_test_X : Vectorized - Vectorized temporal features from the test set. - temp_test_y : Vectorized - Vectorized temporal targets from the test set. - aligned : bool - Whether data is aligned with the tabular data. - - """ - vectorized = [ - (temp_train_X, "temp_train_X"), - (temp_train_y, "temp_train_y"), - (temp_val_X, "temp_val_X"), - (temp_val_y, "temp_val_y"), - (temp_test_X, "temp_test_X"), - (temp_test_y, "temp_test_y"), - ] - for vec, name in vectorized: - if aligned: - save_pickle(vec, self.aligned_path + name) - else: - save_pickle(vec, self.unaligned_path + name) - - ################### - # Combined methods - ################### - - def _vectorize_combined( - self, - temp_vectorized: Vectorized, - tab_aggregated_vec: Vectorized, - ) -> Vectorized: - """Vectorize the combined data. - - Parameters - ---------- - temp_vectorized : Vectorized - Vectorized temporal data. 
- tab_aggregated_vec : Vectorized - Vectorized aggregated tabular data. - - Returns - ------- - Vectorized - Vectorized combined data. - - """ - comb_vectorized = temp_vectorized.concat_over_axis( - self.temp_feature_params["primary_feature"], - tab_aggregated_vec.data, - tab_aggregated_vec.get_index(self.temp_feature_params["primary_feature"]), - ) - comb_vectorized, _ = comb_vectorized.split_out( - self.temp_feature_params["primary_feature"], - self.tab_feature_params[TARGETS], - ) - - comb_vectorized.add_normalizer( - self.temp_feature_params["primary_feature"], - normalization_method=self.temp_norm_params["method"], - ) - - save_pickle(comb_vectorized, self.comb_vectorized_file) - return comb_vectorized - - def _get_intersect_vec( - self, - tab_vectorized: Vectorized, - temp_vectorized: Vectorized, - comb_vectorized: Vectorized, - ) -> Tuple: - """Get the records that are available in all datasets. - - Parameters - ---------- - tab_vectorized : Vectorized - Vectorized tabular data. - temp_vectorized : Vectorized - Vectorized temporal data. - comb_vectorized : Vectorized - Vectorized combined data. - - Returns - ------- - Tuple - Vectorized tabular, temporal, and combined data. - - """ - tab_vectorized, temp_vectorized, comb_vectorized = intersect_vectorized( - [tab_vectorized, temp_vectorized, comb_vectorized], - axes=self.common_feature, - ) - - return tab_vectorized, temp_vectorized, comb_vectorized - - def _split_combined(self, comb_vectorized: Vectorized) -> Tuple: - """Split combined data to train, validation, and test sets. - - Parameters - ---------- - tab_vectorized : Vectorized - Vectorized combined data. - - Returns - ------- - Tuple - A tuple of datasets of splits. All splits are Vectorized objects. - - """ - fractions = self.split_fractions.copy() - comb_train, comb_val, comb_test = split_vectorized( - vecs=[comb_vectorized], - fractions=fractions, - axes=self.common_feature, - )[0] - return comb_train, comb_val, comb_test - - def _get_comb_train( - self, - comb_train: Vectorized, - normalize: bool, - impute: Optional[bool] = True, - ) -> Tuple: - """Get combined train features (normalized) and the targets. - - Parameters - ---------- - comb_train : Vectorized - Vectorized combined data. - normalize : bool - Whether to normalize the data. - impute : bool - Whether to impute values. - - Returns - ------- - Tuple - Tuple of train features and targets. - - """ - comb_train_X, comb_train_y = comb_train.split_out( - self.temp_feature_params["primary_feature"], - self.temp_feature_params[TARGETS], - ) - if impute: - comb_train_X.impute( - self.temp_impute_params["axis"], - self.temp_feature_params["primary_feature"], - self.temp_impute_params["func"], - ) - - if normalize: - comb_train_X = self._normalize(comb_train_X) - - return comb_train_X, comb_train_y - - def _get_comb_val( - self, - comb_val: Vectorized, - normalize: bool, - impute: Optional[bool] = True, - ) -> Tuple: - """Get combined validation features (normalized) and the targets. - - Parameters - ---------- - comb_validation : Vectorized - Vectorized combined data. - normalize : bool - Whether to normalize the data. - impute : bool - Whether to impute values. - - Returns - ------- - Tuple - Tuple of validation features and targets. 
- - """ - comb_val_X, comb_val_y = comb_val.split_out( - self.temp_feature_params["primary_feature"], - self.temp_feature_params[TARGETS], - ) - if impute: - comb_val_X.impute( - self.temp_impute_params["axis"], - self.temp_feature_params["primary_feature"], - self.temp_impute_params["func"], - ) - - if normalize: - comb_val_X = self._normalize(comb_val_X) - - return comb_val_X, comb_val_y - - def _get_comb_test( - self, - comb_test: Vectorized, - normalize: bool, - impute: Optional[bool] = True, - ) -> Tuple: - """Get combined test features (normalized) and target. - - Parameters - ---------- - comb_test : Vectorized - Vectorized combined data. - normalize : bool - Whether to normalize the data. - impute : bool - Whether to impute values. - - Returns - ------- - Tuple - Tuple of test features and targets. - - """ - comb_test_X, comb_test_y = comb_test.split_out( - self.temp_feature_params["primary_feature"], - self.temp_feature_params[TARGETS], - ) - if impute: - comb_test_X.impute( - self.temp_impute_params["axis"], - self.temp_feature_params["primary_feature"], - self.temp_impute_params["func"], - ) - - if normalize: - comb_test_X = self._normalize(comb_test_X) - - return comb_test_X, comb_test_y - - def _save_combined( - self, - comb_train_X, - comb_train_y, - comb_val_X, - comb_val_y, - comb_test_X, - comb_test_y, - ): - """Save combined features and targets for all data splits. - - Parameters - ---------- - comb_train_X : Vectorized - Vectorized combined features from the train set. - comb_train_y : Vectorized - Vectorized combined targets from the train set. - comb_val_X : Vectorized - Vectorized combined features from the validation set. - comb_val_y : Vectorized - Vectorized combined targets from the validation set. - comb_test_X : Vectorized - Vectorized combined features from the test set. - comb_test_y : Vectorized - Vectorized combined targets from the test set. - - """ - vectorized = [ - (comb_train_X, "comb_train_X"), - (comb_train_y, "comb_train_y"), - (comb_val_X, "comb_val_X"), - (comb_val_y, "comb_val_y"), - (comb_test_X, "comb_test_X"), - (comb_test_y, "comb_test_y"), - ] - for vec, name in vectorized: - save_pickle(vec, self.aligned_path + name) - - ################### - # Process methods - ################### - - def process_tabular_one(self) -> Tuple: - """First step of tabular processing. - - 1. Load data. - 2. Get tabular features as an object. - 3. Slice the data if required. - 4. Vectorize. - - Returns - ------- - Tuple - Tuple of vectorized tabular data and tabular features object. - - """ - LOGGER.info("Loading the tabular data.") - cohort = self._load_cohort().reset_index(drop=True) - tab_features = self._get_tabular_features(cohort) - - if self.tab_slice_params["slice"]: - LOGGER.info("Slicing the tabular data.") - _ = self._slice_tabular(tab_features) - - LOGGER.info("Vectorizing the tabular data.") - tab_vectorized = self._vectorize_tabular( - tab_features, - self.tab_norm_params["normalize"], - ) - return tab_vectorized, tab_features - - def process_tabular_two(self, tab_vectorized: Vectorized, aligned: bool) -> None: - """Second step of tabular processing. - - 1. Split. - 2. Get the features and targets for each split. - 3. Save the finalized vectorized data. - - Parameters - ---------- - tab_vectorized : Vectorized - Vectorized tabular data. - aligned : bool - Whether data is aligned with the temporal data. 
- - """ - LOGGER.info("Splitting the tabular data.") - tab_train, tab_val, tab_test = self._split_tabular(tab_vectorized) - - tab_train_X, tab_train_y = self._get_tab_train( - tab_train, - self.tab_norm_params["normalize"], - ) - tab_val_X, tab_val_y = self._get_tab_val( - tab_val, - self.tab_norm_params["normalize"], - ) - tab_test_X, tab_test_y = self._get_tab_test( - tab_test, - self.tab_norm_params["normalize"], - ) - - LOGGER.info("Saving the tabular features and targets for all data splits.") - self._save_tabular( - tab_train_X, - tab_train_y, - tab_val_X, - tab_val_y, - tab_test_X, - tab_test_y, - aligned, - ) - - def process_tabular(self) -> None: - """Process tabular data.""" - tab_vectorized, _ = self.process_tabular_one() - self.process_tabular_two(tab_vectorized, aligned=False) - - def process_temporal_one(self) -> Vectorized: - """First step of temporal processing. - - 1. Aggregate temporal data. - 2. Vectorize temporal features. - 3. Create targets. - 4. Vectorize the whole temporal data. - - Returns - ------- - Vectorized - Vectorized temporal data. - - """ - cleaned_generator = self._load_batches(self.cleaned_dir) - filter_fn = None - if ( - self.temp_params["query"] == mimic.CHARTEVENTS - and self.temp_params["top_n_events"] - ): - LOGGER.info("Getting top %d events", self.temp_params["top_n_events"]) - top_events = get_top_events( - self.cleaned_dir, - self.temp_params["top_n_events"], - ) - filter_fn = lambda events: valid_events(events, top_events) # noqa: E731 - - LOGGER.info("Aggregating the temporal features in batches.") - self._aggregate_temporal_batches(cleaned_generator, filter_fn) - - LOGGER.info("Vectorizing the temporal features in batches.") - agg_generator = self._load_batches(self.aggregated_dir) - self._vectorize_temporal_batches(agg_generator) - - vec_generator = yield_pickled_files(self.vectorized_dir) - temp_vectorized = self._vectorize_temporal_features(vec_generator) - - LOGGER.info("Creating the temporal targets.") - timestamps = self._get_timestamps() - targets = self._create_target(temp_vectorized, timestamps) - - LOGGER.info("Vectorizing the temporal data.") - return self._vectorize_temporal( - temp_vectorized, - targets, - self.temp_norm_params["normalize"], - ) - - def process_temporal_two(self, temp_vectorized: Vectorized, aligned: bool) -> None: - """Second step of temporal processing. - - 1. Split. - 2. Get the features and targets for each split. - 3. Save the finalized vectorized data. - - Parameters - ---------- - temp_vectorized : Vectorized - Vectorized temporal data. - aligned : bool - Whether the data is aligned with the tabular data. - - """ - LOGGER.info("Splitting the temporal data.") - temp_train, temp_val, temp_test = self._split_temporal(temp_vectorized) - temp_train_X, temp_train_y = self._get_temp_train( - temp_train, - self.temp_norm_params["normalize"], - ) - temp_val_X, temp_val_y = self._get_temp_val( - temp_val, - self.temp_norm_params["normalize"], - ) - temp_test_X, temp_test_y = self._get_temp_test( - temp_test, - self.temp_norm_params["normalize"], - ) - - LOGGER.info("Saving the temporal features and targets for data splits.") - self._save_temporal( - temp_train_X, - temp_train_y, - temp_val_X, - temp_val_y, - temp_test_X, - temp_test_y, - aligned, - ) - - def process_temporal(self) -> None: - """Process temporal data.""" - temp_vectorized = self.process_temporal_one() - self.process_temporal_two(temp_vectorized, aligned=False) - - def process_combined_one(self) -> Tuple: - """First step of combined processing. - - 1. 
Process tabular data or load from a file. - 2. Process temporal data or load from a file. - 3. Aggregate tabular data. - 4. Vectorize aggregated tabular data. - 5. Vectorize the combined data. - 6. Get the intersection of the three datasets. - - Returns - ------- - Tuple - Vectorized tabular, temporal, and combined data. - - """ - LOGGER.info("Getting the vectorized tabular data.") - if path.exists(self.tab_vectorized_file): - tab_vectorized = load_pickle(self.tab_vectorized_file) - tab_features = load_pickle(self.tab_features_file) - else: - tab_vectorized, tab_features = self.process_tabular_one() - - LOGGER.info("Getting the vectorized temporal data.") - if path.exists(self.temp_vectorized_file): - temp_vectorized = load_pickle(self.temp_vectorized_file) - else: - temp_vectorized = self.process_temporal_one() - - LOGGER.info("Combining tabular and temporal data.") - tab_aggregated = self._aggregate_tabular(tab_features, temp_vectorized) - tab_aggregated_vec = self._vectorize_agg_tabular(tab_aggregated) - comb_vectorized = self._vectorize_combined(temp_vectorized, tab_aggregated_vec) - tab_vectorized, temp_vectorized, comb_vectorized = self._get_intersect_vec( - tab_vectorized, - temp_vectorized, - comb_vectorized, - ) - return tab_vectorized, temp_vectorized, comb_vectorized - - def process_combined_two(self, comb_vectorized: Vectorized) -> None: - """Second step of combined processing. - - 1. Split. - 2. Get the features and targets for each split. - 3. Save the finalized vectorized data. - - Parameters - ---------- - comb_vectorized : Vectorized - Vectorized combined data. - - """ - LOGGER.info("Splitting the combined data.") - comb_train, comb_val, comb_test = self._split_combined(comb_vectorized) - comb_train_X, comb_train_y = self._get_comb_train( - comb_train, - self.temp_norm_params["normalize"], - ) - comb_val_X, comb_val_y = self._get_comb_val( - comb_val, - self.temp_norm_params["normalize"], - ) - comb_test_X, comb_test_y = self._get_comb_test( - comb_test, - self.temp_norm_params["normalize"], - ) - - LOGGER.info("Saving the combined features and targets for all data splits.") - self._save_combined( - comb_train_X, - comb_train_y, - comb_val_X, - comb_val_y, - comb_test_X, - comb_test_y, - ) - - def process_combined(self) -> None: - """Process combined data.""" - tab_vectorized, temp_vectorized, comb_vectorized = self.process_combined_one() - self.process_tabular_two(tab_vectorized, aligned=True) - self.process_temporal_two(temp_vectorized, aligned=True) - self.process_combined_two(comb_vectorized) diff --git a/use_cases/data_processors/process.py b/use_cases/data_processors/process.py deleted file mode 100644 index 3fca4699a..000000000 --- a/use_cases/data_processors/process.py +++ /dev/null @@ -1,56 +0,0 @@ -"""Data Processor class to process data per use-case.""" - -from cyclops.models.constants import DATA_TYPES, DATASETS, USE_CASES -from use_cases.data_processors.mimiciv import MIMICIVProcessor - - -class DataProcessor: - """Data processor class.""" - - def __init__( - self, - dataset_name: str, - use_case: str, - data_type: str, - ) -> None: - """Initialize processor. - - Parameters - ---------- - dataset_name : str - Dataset name to process the data from. - use_case : str - Use-case to process the data for. - data_type : str - Type of data (tabular, temporal, or combined). 
- - """ - self.dataset_name = dataset_name.lower() - self.use_case = use_case.lower() - self.data_type = data_type.lower() - - self._validate() - self._init_processor() - - def _validate(self) -> None: - """Validate the input arguments.""" - assert self.dataset_name in DATASETS, "[!] Invalid dataset name" - assert self.use_case in USE_CASES, "[!] Invalid use case" - assert ( - self.dataset_name in USE_CASES[self.use_case] - ), "[!] Unsupported use case for this dataset" - assert self.data_type in DATA_TYPES, "[!] Invalid data type" - - def _init_processor(self) -> None: - """Initialize the processor based on the dataset name and the use-case.""" - if self.dataset_name == "mimiciv": - self.processor = MIMICIVProcessor(self.use_case, self.data_type) - - def process_data(self): - """Process the data based on its type.""" - if self.data_type == "tabular": - self.processor.process_tabular() - elif self.data_type == "temporal": - self.processor.process_temporal() - else: - self.processor.process_combined() diff --git a/use_cases/examples/mimiciv/mortality_decompensation.ipynb b/use_cases/examples/mimiciv/mortality_decompensation.ipynb deleted file mode 100644 index 42e212b4f..000000000 --- a/use_cases/examples/mimiciv/mortality_decompensation.ipynb +++ /dev/null @@ -1,433 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Predicting mortality decompensation on MIMICIV\n", - "\n", - "This notebook presents examples of preprocessing data, training, and testing models to predict mortality decompensation on MIMICIV dataset. There are three types of processing and modeling based on the data type:\n", - "1. Tabular: Using tabular data and applying static models.\n", - "2. Temporal: Using temporal data as timeseries and applying temporal models.\n", - "3. Combined: Using both tabular and temporal data and applying temporal models." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from cyclops.models.catalog import list_models\n", - "from cyclops.models.predictor import Predictor\n", - "from use_cases.data_processors.process import DataProcessor" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "DATASET = \"mimiciv\"\n", - "USE_CASE = \"mortality_decompensation\"\n", - "TABULAR_TYPE = \"tabular\"\n", - "TEMPORAL_TYPE = \"temporal\"\n", - "COMBINED_TYPE = \"combined\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Tabular Processing\n", - "\n", - "Tabular processing aims to load the queried data, prepares the data for training and testing, and saves the finalized data splits as vectorized objects.\n", - "\n", - "If the processed data already exists, this step can be skipped." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tabular_processor = DataProcessor(DATASET, USE_CASE, TABULAR_TYPE)\n", - "tabular_processor.process_data()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Tabular modeling\n", - "\n", - "Tabular modeling aims to train a static model on tabular data to predict mortality decompensation." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# List of supported static models\n", - "list_models(\"static\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Initializing the predictor\n", - "tabular_predictor = Predictor(\"xgb_classifier\", DATASET, USE_CASE, TABULAR_TYPE)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tabular_predictor.dataset.n_features" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tabular_predictor.dataset.X_train.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tabular_predictor.dataset.y_train.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tabular_predictor.model.get_params()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Training the model\n", - "tabular_model = tabular_predictor.fit()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Predicting on the test set\n", - "y_pred = tabular_predictor.predict()\n", - "y_pred" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Getting evaluation metrics\n", - "tabular_predictor.evaluate(verbose=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Temporal processing\n", - "\n", - "Temporal processing aims to load the queried data, prepares the timeseries data for training and testing, and saves the finalized data splits as vectorized objects.\n", - "\n", - "If the processed data already exists, this step can be skipped." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "temporal_processor = DataProcessor(DATASET, USE_CASE, TEMPORAL_TYPE)\n", - "temporal_processor.process_data()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Temporal modeling\n", - "\n", - "Temporal modeling aims to train a temporal model on timeseries data to predict mortality decompensation." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# List of supported temopral models\n", - "list_models(\"temporal\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Initializing the predictor\n", - "temporal_predictor = Predictor(\"gru\", DATASET, USE_CASE, TEMPORAL_TYPE)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "temporal_predictor.dataset.n_features" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "temporal_predictor.dataset.X_train.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "temporal_predictor.dataset.y_train.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "temporal_predictor.model.get_params()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Training the model\n", - "temporal_model = temporal_predictor.fit()\n", - "temporal_model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Predicting on the test set\n", - "y_test_labels, y_pred_values, y_pred_labels = temporal_predictor.predict(temporal_model)\n", - "y_pred_values" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Getting evaluation metrics\n", - "temporal_predictor.evaluate(temporal_model, verbose=False)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Combined processing\n", - "\n", - "Combined processing aims to load both tabular and temporal data, and combines them for training and testing, and saves the finalized data splits as vectorized objects.\n", - "\n", - "If the processed data already exists, this step can be skipped." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "combined_processor = DataProcessor(DATASET, USE_CASE, COMBINED_TYPE)\n", - "combined_processor.process_data()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Combined modeling\n", - "\n", - "Combined modeling aims to train a temporal model on timeseries data integrated with tabular data to predict mortality decompensation." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# List of supported temopral models\n", - "list_models(\"temporal\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Initializing the predictor\n", - "combined_predictor = Predictor(\"lstm\", DATASET, USE_CASE, COMBINED_TYPE)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "combined_predictor.dataset.n_features" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "combined_predictor.dataset.X_train.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "combined_predictor.dataset.y_train.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "combined_predictor.model.get_params()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Training the model\n", - "combined_model = combined_predictor.fit()\n", - "combined_model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Predicting on the test set\n", - "y_test_labels, y_pred_values, y_pred_labels = combined_predictor.predict(combined_model)\n", - "y_pred_values" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Getting evaluation metrics\n", - "combined_predictor.evaluate(combined_model, verbose=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "pycyclops-vJuqw9Rd-py3.9", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.12" - }, - "vscode": { - "interpreter": { - "hash": "c3ca27156ee2f087c6753b1a8bfdb2423cbc0389ae963b080434b18581866fbd" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/use_cases/params/mimiciv/mortality_decompensation/constants.py b/use_cases/params/mimiciv/mortality_decompensation/constants.py deleted file mode 100644 index 8ddda46b4..000000000 --- a/use_cases/params/mimiciv/mortality_decompensation/constants.py +++ /dev/null @@ -1,141 +0,0 @@ -"""MIMICIV parameters for mortality decompensation data processing.""" - -from cyclops.process.column_names import ( - ADMIT_TIMESTAMP, - AGE, - DISCHARGE_TIMESTAMP, - ENCOUNTER_ID, - EVENT_NAME, - EVENT_TIMESTAMP, - EVENT_VALUE, - RESTRICT_TIMESTAMP, - SEX, - TIMESTEP, -) -from cyclops.process.constants import ALL, MEAN, STANDARD, TARGETS -from cyclops.process.impute import np_ffill_bfill -from cyclops.utils.file import join, process_dir_save_path - - -CONST_NAME = "mortality_decompensation" -OUTCOME_DEATH = "outcome_death" -DEATHTIME = "deathtime" - -# PATHS -USECASE_ROOT_DIR = join( - "/mnt/data", - "cyclops", - "use_cases", - "mimiciv", - CONST_NAME, -) -DATA_DIR = process_dir_save_path(join(USECASE_ROOT_DIR, "data")) - -QUERIED_DIR = process_dir_save_path(join(DATA_DIR, "0-queried")) -CLEANED_DIR = process_dir_save_path(join(DATA_DIR, "1-cleaned")) 
-AGGREGATED_DIR = process_dir_save_path(join(DATA_DIR, "2-agg")) -VECTORIZED_DIR = process_dir_save_path(join(DATA_DIR, "3-vec")) -FINAL_VECTORIZED = process_dir_save_path(join(DATA_DIR, "4-final")) - -TABULAR_FILE = join(DATA_DIR, "encounters.parquet") -TAB_FEATURES_FILE = join(DATA_DIR, "tab_features.pkl") -TAB_SLICE_FILE = join(DATA_DIR, "tab_slice.pkl") -TAB_AGGREGATED_FILE = join(DATA_DIR, "tab_aggregated.parquet") -TAB_VECTORIZED_FILE = join(DATA_DIR, "tab_vectorized.pkl") -TEMP_AGGREGATED_FILE = join(DATA_DIR, "temp_aggregated.parquet") -TEMP_VECTORIZED_FILE = join(DATA_DIR, "temp_vectorized.pkl") -COMB_VECTORIZED_FILE = join(DATA_DIR, "comb_vectorized.pkl") - -ALIGNED_PATH = join(FINAL_VECTORIZED, "aligned_") -UNALIGNED_PATH = join(FINAL_VECTORIZED, "unaligned_") - -# PARAMS -COMMON_FEATURE = ENCOUNTER_ID -SKIP_N = 0 -SPLIT_FRACTIONS = [0.8, 0.1, 0.1] - -TABULAR_FEATURES = { - "primary_feature": ENCOUNTER_ID, - "outcome": OUTCOME_DEATH, - "targets": [OUTCOME_DEATH], - "features": [AGE, SEX, "admission_type", "admission_location", OUTCOME_DEATH], - "features_types": {}, -} - -TABULAR_NORM = { - "normalize": True, - "method": STANDARD, -} - -TABULAR_SLICE = { - "slice": False, - "slice_map": {AGE: 80}, - "slice_query": "", - "replace": True, -} - -TABULAR_AGG = { - "strategy": ALL, - "index": ENCOUNTER_ID, - "var_name": EVENT_NAME, - "value_name": EVENT_VALUE, -} - -TEMPORAL_PARAMS = { - "query": "chartevents", - "top_n_events": 150, - "timestep_size": 24, # Make a prediction every day - "window_duration": 144, # Predict for the first 6 days of admission - "predict_offset": 336, # Death occurs in the next 2 weeks -} - -TEMPORAL_NORM = { - "normalize": True, - "method": STANDARD, -} - -TEMPORAL_FEATURES = { - "primary_feature": EVENT_NAME, - "features": [EVENT_VALUE], - "groupby": [ENCOUNTER_ID, EVENT_NAME], - "timestamp_col": EVENT_TIMESTAMP, - "outcome": TARGETS + " - " + OUTCOME_DEATH, - "targets": [TARGETS + " - " + OUTCOME_DEATH], -} - - -TIMESTAMPS = { - "use_tabular": True, - "columns": [ENCOUNTER_ID, ADMIT_TIMESTAMP, DISCHARGE_TIMESTAMP, DEATHTIME], - "start_columns": [ENCOUNTER_ID, ADMIT_TIMESTAMP], - "start_index": ENCOUNTER_ID, - "rename": {ADMIT_TIMESTAMP: RESTRICT_TIMESTAMP}, -} - -TIMESTEPS = {"new_timestamp": "after_admit", "anchor": ADMIT_TIMESTAMP} - -TEMPORAL_TARGETS = { - "target_timestamp": DEATHTIME, - "ref_timestamp": DISCHARGE_TIMESTAMP, -} - -TEMPORAL_SLICE = { - "slice": False, - "slice_map": {}, - "replace": False, -} - -TEMPORAL_AGG = { - "aggfuncs": {EVENT_VALUE: MEAN}, - "timestamp_col": EVENT_TIMESTAMP, - "time_by": ENCOUNTER_ID, - "agg_by": [ENCOUNTER_ID, EVENT_NAME], - "timestep_size": 24, - "window_duration": 144, -} - -TEMPORAL_IMPUTE = { - "impute": True, - "axis": TIMESTEP, - "func": np_ffill_bfill, -} diff --git a/use_cases/params/mimiciv/mortality_decompensation/constants_v1.py b/use_cases/params/mimiciv/mortality_decompensation/constants_v1.py deleted file mode 100644 index 412cd007b..000000000 --- a/use_cases/params/mimiciv/mortality_decompensation/constants_v1.py +++ /dev/null @@ -1,57 +0,0 @@ -"""MIMICIV mortality decompensation use case constants.""" - -from cyclops.process.column_names import AGE, SEX -from cyclops.process.constants import TARGETS -from cyclops.utils.file import join, process_dir_save_path - - -CONST_NAME = "mortality_decompensation" -USECASE_ROOT_DIR = join( - "/mnt/data", - "cyclops", - "use_cases", - "mimiciv", - CONST_NAME, -) -DATA_DIR = process_dir_save_path(join(USECASE_ROOT_DIR, "./data")) - -OUTCOME_DEATH = 
"outcome_death" -SPLIT_FRACTIONS = [0.8, 0.1, 0.1] - -ENCOUNTERS_FILE = join(DATA_DIR, "encounters.parquet") -AGGREGATED_FILE = join(DATA_DIR, "aggregated.parquet") -TAB_FEATURES_FILE = join(DATA_DIR, "tab_features.pkl") -TAB_VECTORIZED_FILE = join(DATA_DIR, "tab_vectorized.pkl") -TEMP_VECTORIZED_FILE = join(DATA_DIR, "temp_vectorized.pkl") - -# Tabular -TAB_TARGETS = [OUTCOME_DEATH] -TAB_FEATURES = [ - AGE, - SEX, - "admission_type", - "admission_location", -] + TAB_TARGETS -TAB_FEATURES_TYPES: dict = {} - -# Temporal -TIMESTEP_SIZE = 24 # Make a prediction every day -WINDOW_DURATION = 144 # Predict for the first 6 days of admission -PREDICT_OFFSET = 24 * 14 # Death in occurs in the next 2 weeks - -TOP_N_EVENTS = 150 - -OUTCOME_DEATH_TEMP = TARGETS + " - " + OUTCOME_DEATH -TEMP_TARGETS = [OUTCOME_DEATH_TEMP] - -TARGET_TIMESTAMP = "deathtime" - -QUERIED_DIR = process_dir_save_path(join(DATA_DIR, "0-queried")) -CLEANED_DIR = process_dir_save_path(join(DATA_DIR, "1-cleaned")) -AGGREGATED_DIR = process_dir_save_path(join(DATA_DIR, "2-agg")) -VECTORIZED_DIR = process_dir_save_path(join(DATA_DIR, "3-vec")) - -# Saving final vectorized -FINAL_VECTORIZED = process_dir_save_path(join(DATA_DIR, "4-final")) -TAB_UNALIGNED = join(FINAL_VECTORIZED, "unaligned_") -TAB_VEC_COMB = join(FINAL_VECTORIZED, "aligned_") diff --git a/use_cases/util.py b/use_cases/util.py deleted file mode 100644 index e18d8f018..000000000 --- a/use_cases/util.py +++ /dev/null @@ -1,154 +0,0 @@ -"""Utility functions shared across use-cases.""" - -import importlib -import types -from functools import reduce -from typing import Dict, List, Mapping, Optional, Tuple, Union - -import numpy as np -import pandas as pd -from datasets import Dataset, DatasetDict - -from cyclops.data.utils import is_out_of_core -from cyclops.process.column_names import EVENT_NAME, EVENT_VALUE -from cyclops.utils.file import yield_dataframes - - -def get_use_case_params(dataset: str, use_case: str) -> types.ModuleType: - """Import parameters specific to each use-case. - - Parameters - ---------- - dataset: str - Name of the dataset, e.g. mimiciv. - use_case: str - Name of the use-case, e.g. mortality_decompensation. - - Returns - ------- - types.ModuleType - Imported constants module with use-case parameters. - - """ - return importlib.import_module( - ".".join(["use_cases", "params", dataset, use_case, "constants"]), - ) - - -def get_top_events(events_path: str, n_events: int) -> np.ndarray: - """Get top events from events data saved in batches. - - Parameters - ---------- - events_path : str - Path to the directory of saved events. - n_events : int - The number of top events. - - Returns - ------- - np.ndarray - The array of the top events names. - - """ - all_top_events = [] - for _, events in enumerate(yield_dataframes(events_path, log=False)): - top_events = ( - events[EVENT_NAME][~events[EVENT_VALUE].isna()] - .value_counts()[:n_events] - .index - ) - - all_top_events.append(top_events) - - del events - - # Take only the events common to every file - top_events = reduce(np.intersect1d, tuple(all_top_events)) - return sorted(top_events) - - -def valid_events(events: pd.DataFrame, top_events: np.ndarray) -> pd.DataFrame: - """Keep the events that are included in the top events. - - Parameters - ---------- - events : pd.DataFrame - The events dataframe. - top_events : np.ndarray - The list of top events. - - Returns - ------- - pd.DataFrame - The events dataframe including only top events. 
- - """ - return events[events[EVENT_NAME].isin(top_events)] - - -def get_pandas_df( - dataset: Union[Dataset, DatasetDict, Mapping], - feature_cols: Optional[List[str]] = None, - label_cols: Optional[str] = None, -) -> Union[Tuple[pd.DataFrame, pd.Series], Dict[str, Tuple[pd.DataFrame, pd.Series]]]: - """Convert dataset to pandas dataframe. - - NOTE: converting to pandas does not work with IterableDataset/IterableDatasetDict - (i.e. when dataset is loaded with stream=True). So, this function should only be - used with datasets that are loaded with stream=False and are small enough to fit - in memory. Use :func:`is_out_of_core` to check if dataset is too large to fit in - memory. - - - Parameters - ---------- - dataset : Union[Dataset, DatasetDict, Mapping] - Dataset to convert to pandas dataframe. - feature_cols : List[str], optional - List of feature columns to include in the dataframe, by default None - label_cols : str, optional - Label column to include in the dataframe, by default None - - Returns - ------- - Union[Tuple[pd.DataFrame, pd.Series], Dict[str, Tuple[pd.DataFrame, pd.Series]]] - Pandas dataframe or dictionary of pandas dataframes. - - Raises - ------ - TypeError - If dataset is not a Dataset, DatasetDict, or Mapping. - - """ - if isinstance(dataset, (DatasetDict, Mapping)): - return { - k: get_pandas_df(v, feature_cols=feature_cols, label_cols=label_cols) - for k, v in dataset.items() - } - if isinstance(dataset, Dataset) and not is_out_of_core(dataset.dataset_size): - # validate feature_cols and label_col - if feature_cols is not None and not set(feature_cols).issubset( - dataset.column_names, - ): - raise ValueError("feature_cols must be a subset of dataset column names.") - if label_cols is not None and not set(label_cols).issubset( - dataset.column_names, - ): - raise ValueError("label_col must be a column name of dataset.") - - df = dataset.to_pandas(batched=False) # set batched=True for large datasets - - if feature_cols is not None and label_cols is not None: - pd_dataset = (df[feature_cols], df[label_cols]) - elif label_cols is not None: - pd_dataset = (df.drop(label_cols, axis=1), df[label_cols]) - elif feature_cols is not None: - pd_dataset = (df[feature_cols], None) - else: - pd_dataset = (df, None) - return pd_dataset - - raise TypeError( - f"Expected dataset to be a Dataset or DatasetDict. Got: {type(dataset)}", - )
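
Since this patch deletes use_cases/util.py, code that still needs the feature/label split that get_pandas_df provided can fall back on the Hugging Face Datasets API directly. A minimal sketch, assuming a small in-memory Dataset and illustrative column names (age, sex, outcome_death are examples only, not names required by CyclOps):

    from datasets import Dataset

    # Tiny in-memory dataset; the column names here are illustrative only.
    ds = Dataset.from_dict(
        {"age": [63, 71], "sex": ["M", "F"], "outcome_death": [0, 1]},
    )

    df = ds.to_pandas()       # only for datasets small enough to fit in memory
    X = df[["age", "sex"]]    # feature columns
    y = df["outcome_death"]   # label column

For a DatasetDict, the same conversion can be applied per split; for datasets too large to fit in memory, keep them as streamed/iterable datasets and avoid the pandas round-trip, as the removed helper's docstring cautioned.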