Skip to content

Commit

Permalink
Merge pull request #36 from loft-br/feat/refactor-extrapolation
Browse files Browse the repository at this point in the history
Feat/refactor extrapolation
  • Loading branch information
GabrielGimenez authored Jun 10, 2021
2 parents 1ad72c4 + c7e0010 commit fa26300
Show file tree
Hide file tree
Showing 11 changed files with 117 additions and 2,544 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ from xgbse.extrapolation import extrapolate_constant_risk
survival = bootstrap_estimator.predict(X_valid)

# extrapolating
survival_ext = extrapolate_constant_risk(survival, 450, 11)
survival_ext = extrapolate_constant_risk(survival, 450, 15)
```

<img src="img/extrapolation.png">
Expand Down Expand Up @@ -407,7 +407,7 @@ To cite this repository:
author = {Davi Vieira and Gabriel Gimenez and Guilherme Marmerola and Vitor Estima},
title = {XGBoost Survival Embeddings: improving statistical properties of XGBoost survival analysis implementation},
url = {http://github.com/loft-br/xgboost-survival-embeddings},
version = {0.2.1},
version = {0.2.2},
year = {2021},
}
```
2 changes: 1 addition & 1 deletion docs/examples/extrapolation_example.md
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ Notice that this predicted survival curve does not end at zero (cure fraction du
from xgbse.extrapolation import extrapolate_constant_risk

# extrapolating predicted survival
survival_ext = extrapolate_constant_risk(survival, 450, 11)
survival_ext = extrapolate_constant_risk(survival, 450, 15)
survival_ext.head()
```

Expand Down
2,583 changes: 75 additions & 2,508 deletions examples/extrapolation_example.ipynb

Large diffs are not rendered by default.

Binary file modified img/extrapolation.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

setuptools.setup(
name="xgbse",
version="0.2.1",
version="0.2.2",
author="Loft Data Science Team",
author_email="bandits@loft.com.br",
description="Improving XGBoost survival analysis with embeddings and debiased estimators",
Expand Down
6 changes: 4 additions & 2 deletions tests/test_extrapolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@
)

preds = xgbse_model.predict(X_test)
interval = 10
final_time = max(time_bins) + 1000
n_windows = 100
final_time = max(T_train) + 1000
preds_ext = extrapolate_constant_risk(preds, final_time=final_time, n_windows=n_windows)

preds_ext = extrapolate_constant_risk(preds, final_time=final_time, intervals=interval)


def extrapolation_shape():
Expand Down
2 changes: 1 addition & 1 deletion xgbse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ._meta import XGBSEBootstrapEstimator


__version__ = "0.2.1"
__version__ = "0.2.2"

__all__ = [
"XGBSEDebiasedBCE",
Expand Down
6 changes: 2 additions & 4 deletions xgbse/_debiased_bce.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

# lib utils
from xgbse._base import XGBSEBaseEstimator, DummyLogisticRegression
from xgbse.converters import convert_data_to_xgb_format, convert_y
from xgbse.converters import convert_data_to_xgb_format, convert_y, hazard_to_survival

# at which percentiles will the KM predict
from xgbse.non_parametric import get_time_bins, calculate_interval_failures
Expand Down Expand Up @@ -346,9 +346,7 @@ def _predict_from_lr_list(self, lr_estimators, leaves_encoded, time_bins):

# converting these interval predictions
# to cumulative survival curve
preds = (1 - preds).cumprod(axis=1)

return preds
return hazard_to_survival(preds)

def predict(self, X, return_interval_probs=False):
"""
Expand Down
13 changes: 13 additions & 0 deletions xgbse/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,16 @@ def build_xgb_cox_dmatrix(X, T, E):
target = np.where(E, T, -T)

return xgb.DMatrix(X, label=target)


def hazard_to_survival(interval):
    """Turn hazards (interval probabilities of event) into a survival curve.

    Args:
        interval ([pd.DataFrame, np.array]): hazards (interval probabilities of event),
            usually result of predict or result from _get_point_probs_from_survival

    Returns:
        [pd.DataFrame, np.array]: survival curve
    """
    # probability of NOT having the event within each individual interval
    interval_survival = 1 - interval

    # running product along axis 1 accumulates those per-interval
    # probabilities into the survival curve
    return interval_survival.cumprod(axis=1)
38 changes: 17 additions & 21 deletions xgbse/extrapolation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import numpy as np
import pandas as pd
from xgbse.non_parametric import _get_conditional_probs_from_survival
from xgbse.converters import hazard_to_survival


def extrapolate_constant_risk(survival, final_time, n_windows, lags=-1):
def extrapolate_constant_risk(survival, final_time, intervals, lags=-1):
"""
Extrapolate a survival curve assuming constant risk.
Expand All @@ -13,7 +14,7 @@ def extrapolate_constant_risk(survival, final_time, n_windows, lags=-1):
final_time (Float): Final time for extrapolation
n_windows (Int): Number of time windows to compute from last time window in survival to final_time
intervals (Int): Time in each interval between last time in survival dataframe and final time
lags (Int): Lags to compute constant risk.
if negative, will use the last "lags" values
Expand All @@ -24,28 +25,23 @@ def extrapolate_constant_risk(survival, final_time, n_windows, lags=-1):
pd.DataFrame: Survival dataset with appended extrapolated windows
"""

# calculating conditionals and risk at each time window
conditionals = _get_conditional_probs_from_survival(survival)
window_risk = 1 - conditionals
last_time = survival.columns[-1]
# creating windows for extrapolation
# `intervals` is added to both bounds so the grid skips last_time (already in the survival dataframe)
# and still reaches final_time in the resulting dataframe
extrap_windows = np.arange(last_time + intervals, final_time + intervals, intervals)

# calculating window sizes
time_bins = window_risk.columns.to_series()
window_sizes = time_bins - time_bins.shift(1).fillna(0)
# calculating conditionals and hazard at each time window
hazards = _get_conditional_probs_from_survival(survival)

# using window sizes to calculate risk per unit time and average risk
risk_per_unit_time = np.power(window_risk, 1 / window_sizes)
average_risk = risk_per_unit_time.iloc[:, lags:].mean(axis=1)
# calculating avg hazard for desired lags
constant_haz = hazards.values[:, lags:].mean(axis=1).reshape(-1, 1)

# creating windows for extrapolation
last_time = survival.columns[-1]
extrap_windows = np.linspace(last_time, final_time, n_windows) - last_time
# repeat the constant hazard once per extrapolated window
constant_haz = np.tile(constant_haz, len(extrap_windows))

# loop for extrapolated windows
for delta_t in extrap_windows:
constant_haz = pd.DataFrame(constant_haz, columns=extrap_windows)

# running constant risk extrapolation
extrap_survival = np.power(average_risk, delta_t) * survival.iloc[:, -1]
extrap_survival = pd.Series(extrap_survival, name=last_time + delta_t)
survival = pd.concat([survival, extrap_survival], axis=1)
hazards = pd.concat([hazards, constant_haz], axis=1)

return survival
return hazard_to_survival(hazards)
5 changes: 1 addition & 4 deletions xgbse/non_parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,7 @@ def calculate_survival_func(E_sorted):
# product argument for survival
survival_prod_arg = 1 - (E_sorted / at_risk)

# cumulative product of argument is the survival function
survival_func = np.cumprod(survival_prod_arg, axis=1)

return survival_func
return np.cumprod(survival_prod_arg, axis=1)


def calculate_confidence_intervals(E_sorted, survival_func, z):
Expand Down

0 comments on commit fa26300

Please sign in to comment.