This repository contains the TensorFlow code to replicate the experiments in our paper *Survival Cluster Analysis*, accepted at the ACM Conference on Health, Inference, and Learning (ACM CHIL) 2020:
```
@inproceedings{chapfuwa2020survival,
  title={Survival Cluster Analysis},
  author={Paidamoyo Chapfuwa and Chunyuan Li and Nikhil Mehta and Lawrence Carin and Ricardo Henao},
  booktitle={ACM Conference on Health, Inference, and Learning},
  year={2020}
}
```
Illustration of Survival Cluster Analysis (SCA). The latent space has a mixture-of-distributions structure, illustrated as three mixture components. An observation x is mapped to its latent representation z via a deterministic encoding, which is then used to stochastically predict (via sampling) the time-to-event distribution p(t|x).
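The mapping described above can be sketched in a few lines. Everything below is an illustrative stand-in, not the paper's architecture: the one-layer tanh encoder and the exponential form of p(t|x) are assumptions made only to show the deterministic-encode / stochastic-sample split.

```python
import numpy as np

rng = np.random.default_rng(0)

def encode(x, W, b):
    # Deterministic encoding x -> z (a single dense tanh layer stands
    # in for the paper's encoder network).
    return np.tanh(x @ W + b)

def sample_times(z, v, n_samples=1000):
    # Stochastic time-to-event prediction: draw samples from p(t|x).
    # An exponential likelihood with rate softplus(z . v) is an
    # illustrative stand-in for the model's learned distribution.
    rate = np.log1p(np.exp(z @ v))
    return rng.exponential(scale=1.0 / rate, size=n_samples)

x = rng.normal(size=10)             # one observation's covariates
W = rng.normal(size=(10, 4)) * 0.1  # toy encoder weights
b = np.zeros(4)
v = rng.normal(size=4)

z = encode(x, W, b)      # latent representation z
t = sample_times(z, v)   # samples from p(t|x)
```

Point predictions (e.g. the median of the sampled times) and uncertainty intervals both come from the same sampled set `t`.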
Cluster-specific Kaplan-Meier survival profiles for three clustering methods on the SLEEP dataset. Our model (SCA) identifies high-, medium- and low-risk individuals, demonstrating the need to account for time-to-event information, via a non-linear transformation of covariates, when clustering survival datasets.
The code is implemented with the following dependencies:
- Python 3.6.4
- TensorFlow 1.8.0
- Additional python packages can be installed by running:
```
pip install -r requirements.txt
```
We consider the following datasets:
- SUPPORT
- Flchain
- SEER
- SLEEP: A subset of the Sleep Heart Health Study (SHHS), a multi-center cohort study implemented by the National Heart Lung & Blood Institute to determine the cardiovascular and other consequences of sleep-disordered breathing.
- Framingham: A subset (Framingham Offspring) of the longitudinal study of heart disease dataset, initially for predicting 10-year risk for future coronary heart disease (CHD).
- EHR: A large study from Duke University Health System centered around inpatient visits due to comorbidities in patients with Type-2 diabetes.
For convenience, we provide pre-processing scripts for all datasets (except EHR and Framingham). In addition, the data directory contains the downloaded Flchain and SUPPORT datasets.
Please modify the train arguments as needed:
- `dataset`: one of the four public datasets `{flchain, support, seer, sleep}`; the default is `support`
- `n_clusters`: the cluster upper bound K; the default is `25`
- `gamma_0`: the Dirichlet process concentration parameter, selected from `{2, 3, 4, 8}`; the default is `2`
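For intuition on how `gamma_0` and `n_clusters` interact, the truncated stick-breaking construction below shows how a Dirichlet process concentration parameter shapes mixture weights under a K-component upper bound. This is a generic DP sketch for illustration, not the repository's implementation.

```python
import numpy as np

def stick_breaking(gamma_0, n_clusters, rng):
    # Truncated stick-breaking: draw Beta(1, gamma_0) fractions and
    # convert them to mixture weights. Larger gamma_0 spreads mass
    # over more clusters; smaller gamma_0 concentrates it in a few,
    # so K (n_clusters) acts only as an upper bound.
    betas = rng.beta(1.0, gamma_0, size=n_clusters)
    betas[-1] = 1.0  # close the stick at the truncation level K
    remaining = np.concatenate([[1.0], np.cumprod(1.0 - betas[:-1])])
    return betas * remaining

rng = np.random.default_rng(0)
weights = stick_breaking(gamma_0=2.0, n_clusters=25, rng=rng)
```

The resulting `weights` sum to one over the K components; with `gamma_0 = 2` most of the mass typically falls on a small number of clusters.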
- To train SCA, run:
```
python train.py --dataset support --n_clusters 25 --gamma_0 2
```
- The hyper-parameter settings can be found in configs.py
Once the networks are trained and the results are saved, we extract the following key results:
- Training and evaluation metrics are logged in model.log
- Epoch based cost function plots can be found in the plots directory
- Numpy files used to generate the calibration and cluster plots are saved in the matrix directory
- Run Calibration.ipynb to generate the calibration results and Clustering.ipynb for the clustering results
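The cluster-specific survival profiles shown in the notebooks are Kaplan-Meier curves. A minimal estimator, assuming per-subject arrays of event times, event indicators (1 = event, 0 = censored), and cluster assignments like those saved in the matrix directory, can be sketched as:

```python
import numpy as np

def kaplan_meier(times, events):
    # Kaplan-Meier estimate S(t) at each distinct event time:
    # S(t) = prod over event times t_i <= t of (1 - d_i / n_i),
    # where d_i = events at t_i and n_i = subjects still at risk.
    order = np.argsort(times)
    times, events = times[order], events[order]
    event_times = np.unique(times[events == 1])
    surv, s = [], 1.0
    for t in event_times:
        at_risk = np.sum(times >= t)
        died = np.sum((times == t) & (events == 1))
        s *= 1.0 - died / at_risk
        surv.append(s)
    return event_times, np.array(surv)

# Toy example: one curve per cluster from (time, event, cluster) arrays.
times = np.array([2., 3., 3., 5., 8., 9.])
events = np.array([1, 1, 0, 1, 0, 1])
clusters = np.array([0, 0, 0, 1, 1, 1])
curves = {c: kaplan_meier(times[clusters == c], events[clusters == c])
          for c in np.unique(clusters)}
```

Plotting each cluster's step function `(event_times, surv)` reproduces the kind of per-cluster survival profile shown above; in practice a library such as lifelines handles ties and confidence intervals.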
This work leverages the calibration framework from SFM and the accuracy objective from DATE. Contact Paidamoyo for issues relevant to this project.