This repository contains the PyTorch code to replicate the experiments in our paper "Enabling Counterfactual Survival Analysis with Balanced Representations", accepted at the ACM Conference on Health, Inference, and Learning (ACM CHIL) 2021:
@inproceedings{chapfuwa2021enabling,
title={Enabling Counterfactual Survival Analysis with Balanced Representations},
author={Chapfuwa, Paidamoyo and Assaad, Serge and Zeng, Shuxi and Pencina, Michael J and Carin, Lawrence and Henao, Ricardo},
booktitle={ACM Conference on Health, Inference, and Learning},
year={2021}
}
- ACTG: A longitudinal randomized controlled trial (RCT) comparing monotherapy (Zidovudine or Didanosine) with combination therapy in HIV patients
- Framingham: A subset (Framingham Offspring) of the longitudinal heart disease study dataset, used to predict the effects of statins on survival time
- See actg_synthetic.ipynb to modify the generated ACTG-Synthetic data
The code is implemented with the following dependencies:
- Python 3.6.4
- PyTorch 0.4.1
- Additional Python packages can be installed by running:
pip install -r requirements.txt
- To train all models (CSA, CSA-INFO, AFT, AFT-Weibull, SR), run param_search.py:
python param_search.py
- Hyperparameter settings can be found in .config/configs.txt
Once the networks are trained and the results are saved, we extract the following key results:
- Training and evaluation metrics are logged in model_*.log
- Epoch-based cost function plots can be found in the plots directory
- NumPy files used to generate calibration and cluster plots are saved in the matrix directory
- To select the best validation alpha, compute likelihoods as shown in evaluation_alpha_selection.ipynb
- Run evaluation_ITE.ipynb to generate factual and causal metrics, and evaluation_stratify_HR.ipynb for hazard ratio (HR) risk stratification
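
The alpha selection step listed above can be illustrated with a minimal sketch. It is not the code from evaluation_alpha_selection.ipynb. It assumes that each candidate alpha has a list of per-sample validation log-likelihoods and that the alpha with the highest mean is preferred. The function name `select_best_alpha` and the numbers are hypothetical placeholders, not results from the paper.

```python
from statistics import mean


def select_best_alpha(val_log_likelihoods):
    """Return (best_alpha, mean_scores) for a dict of alpha -> log-likelihoods.

    Assumption: a higher mean validation log-likelihood is better.
    """
    if not val_log_likelihoods:
        raise ValueError("no candidate alphas provided")
    # Average the validation log-likelihoods for each candidate alpha.
    scores = {alpha: mean(lls) for alpha, lls in val_log_likelihoods.items()}
    # Pick the alpha with the largest mean log-likelihood.
    best_alpha = max(scores, key=scores.get)
    return best_alpha, scores


# Placeholder values for illustration only; replace with likelihoods
# computed from your own trained models.
example = {
    0.0: [-2.31, -2.28, -2.35],
    0.1: [-2.10, -2.05, -2.12],
    1.0: [-2.40, -2.45, -2.38],
}
best_alpha, scores = select_best_alpha(example)
print(best_alpha)
```

If the notebook uses a different criterion (for example, a negative log-likelihood to be minimized), swap `max` for `min` accordingly.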
This work leverages the calibration framework from SFM and the accuracy objective from DATE. Contact Paidamoyo for issues relevant to this project.