-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathchapter15.qmd
322 lines (240 loc) · 8.22 KB
/
chapter15.qmd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
---
title: "Chapter 15 - Machine Learning in Survival Analysis"
---
## Slides
Lecture slides [here](chap15.html){target="_blank"}. (To convert HTML to PDF, press E $\to$ Print $\to$ Destination: Save as PDF)
## Base R Code
```{r}
#| code-fold: true
#| code-summary: "Show the code"
#| eval: false
##################################################################
# This code generates all numerical results in chapter 15. ##
##################################################################
# library("glmpath")
library("glmnet")
library(survival)
# Load the complete German Breast Cancer data set
gbc <- read.table("Data//German Breast Cancer Study//gbc.txt")

# Reduce to first-event data: order the rows by id and, within each id,
# by time, then keep only the first row of every id
o <- order(gbc$id, gbc$time)
gbc <- gbc[o, ]
data.CE <- gbc[!duplicated(gbc$id), ]

# Collapse the status codes: any positive status (1 or 2) becomes 1
data.CE$status <- as.numeric(data.CE$status > 0)

# Binary indicator for age <= 40 years
data.CE$age40 <- as.numeric(data.CE$age <= 40)

n <- nrow(data.CE)
head(data.CE)

## Training set: the first 400 entries of a seeded random permutation of
## the row indices; all remaining rows form the test set
set.seed(1234)
ind <- sample(seq_len(n))[seq_len(400)]
train <- data.CE[ind, ]
test <- data.CE[-ind, ]
# Predictor list for Cox-lasso
pred_list <- c("hormone", "age40", "meno", "size", "grade",
               "nodes", "prog", "estrg")
# Training covariate matrix, one column per entry of pred_list
Z <- as.matrix(train[, pred_list])
# Training outcome: follow-up time and event indicator
time <- train$time
status <- train$status
# dimension of covariate matrix
dim(Z)
# 400x 8

# Coefficient paths as a function of lambda
# alpha = 1: L_1 penalty (lasso)
obj <- glmnet(Z, Surv(time, status), family = "cox", alpha = 1)
summary(obj)
# 10-fold (default) cross-validation over the lambda sequence
obj.cv <- cv.glmnet(Z, Surv(time, status), family = "cox", alpha = 1)

# Figure parameters: two panels side by side
par(mfrow = c(1, 2))
# par(cex = 1.2)
# Left panel: coefficient paths as a function of log-lambda
library(plotmo) # for plot_glmnet
plotmo::plot_glmnet(obj, lwd = 2)
# plot(obj,xvar="lambda",lwd=2,label=TRUE)
# Right panel: the validation error (partial-likelihood deviance)
# as a function of log-lambda
plot(obj.cv)

# the optimal lambda
obj.cv$lambda.min
log(obj.cv$lambda.min)
# the beta at optimal lambda (a one-column coefficient matrix)
beta <- coef(obj.cv, s = "lambda.min")
# The non-zero coefficients. The column is pulled out as a named vector
# BEFORE filtering: subsetting the one-column matrix by rows with the default
# drop = TRUE loses the variable name when exactly one coefficient is
# non-zero, which would leave `selected` empty below.
beta_vec <- beta[, 1]
beta.selected <- beta_vec[abs(beta_vec) > 0]
# print out the non-zero coefficients
beta.selected
# number of non-zero coefficients
length(beta.selected)

# Refit the training data using the variables selected
selected <- names(beta.selected)
if (length(selected) == 0) {
  stop("Cox-lasso selected no predictors at lambda.min; nothing to refit",
       call. = FALSE)
}
obj <- coxph(Surv(train$time, train$status) ~ as.matrix(train[, selected]))
###############################################################################
## Predicted survival curves from a fitted Cox model.
## Input:
###  obj: a fitted coxph object; its coefficients and the baseline cumulative
###       hazard returned by basehaz(obj, centered = FALSE) are used
###  Z:   (n x p) numeric covariate matrix of the subjects to predict, with
###       one column per model coefficient, in the same order
## Output (a list):
###  St: (n x m) matrix with the ith row the predicted survival rates for the
###      ith subject; the first column is the survival rate 1 at time 0
###  t:  m-vector of times (0 followed by the baseline-hazard time points)
###############################################################################
predsurv_cox <- function(obj, Z) {
  beta <- obj$coefficients
  # Fail early with a readable message instead of a bare
  # "non-conformable arguments" error from the matrix product below
  if (!is.matrix(Z) || ncol(Z) != length(beta)) {
    stop("Z must be a matrix with one column per model coefficient (",
         length(beta), ")", call. = FALSE)
  }
  # Baseline cumulative hazard at covariate value zero (uncentered)
  bhaz <- basehaz(obj, centered = FALSE)
  cum_haz <- bhaz$hazard
  times <- c(0, bhaz$time)
  # Subject-specific relative risk exp(Z beta), an (n x 1) matrix
  risk <- exp(Z %*% beta)
  # S_i(t_j) = exp(-risk_i * cum_haz_j), prefixed with S_i(0) = 1
  St <- cbind(rep(1, nrow(Z)), exp(-(risk %*% t(cum_haz))))
  list(St = St, t = times)
}
## Cox-lasso predictions for the test set: the matrix of predicted survival
## rates and the time grid its columns refer to
pred_surv <- predsurv_cox(obj, Z = as.matrix(test[, selected]))
St <- pred_surv[["St"]]
t <- pred_surv[["t"]]

## Kaplan-Meier estimate of the censoring distribution on the test set
## (censoring plays the role of the event, hence the indicator 1 - status);
## both vectors are prefixed with the starting point G(0) = 1
G_obj <- summary(survfit(Surv(time, 1 - status) ~ 1, data = test))
tc <- c(0, G_obj[["time"]])
Gtc <- c(1, G_obj[["surv"]])
#####################################################
# A function calculating the Brier score BS(tau)
# Input:
#   tau: (scalar) time at which the score is evaluated
#   (time, status): observed test outcomes; status == 1 marks an event
#   St,t: predicted survival rates; St[i, j] is the predicted survival
#         rate of subject i at time t[j]
#   Gtc,tc: KM estimates for censoring distributions; Gtc[j] is the
#           value of the estimate from tc[j] up to the next tc
# Output:
#   Mean over subjects of the squared prediction errors weighted by 1/G:
#     event by tau (time <= tau, status == 1): S_i(tau)^2 / G(time_i)
#     still at risk (time > tau):              (1 - S_i(tau))^2 / G(tau)
#     otherwise (censored by tau):             0
#   Subjects whose contribution is NA are dropped from the mean.
#   Returns NaN when there are no subjects.
######################################################
BSfun <- function(tau, time, status, St, t, Gtc, tc) {
  n <- length(time)
  if (n == 0L) {
    return(NaN)
  }
  # Value at each x of the right-continuous step function with jump points
  # `knots` and levels `values`; NA where x lies before the first knot
  step_at <- function(x, knots, values) {
    idx <- vapply(x, function(x_i) sum(knots <= x_i), integer(1))
    idx[which(idx == 0L)] <- NA_integer_
    values[idx]
  }
  # Column of St holding the predictions at the last grid time <= tau
  pos <- sum(t <= tau)
  S <- if (pos > 0) St[seq_len(n), pos] else rep(NA_real_, n)
  # Subjects with an observed event by tau
  event <- time <= tau & status == 1
  # Censoring weight: G(time_i) for events by tau, G(tau) for everyone else
  G <- ifelse(event, step_at(time, tc, Gtc), step_at(tau, tc, Gtc))
  BSvec <- ifelse(event, S^2 / G,
                  ifelse(time > tau, (1 - S)^2 / G, 0))
  mean(BSvec, na.rm = TRUE)
}
# Compute the Brier score at tau=12 to 60 months
# under Cox-lasso
tau_list <- 12:60
# One BSfun call per tau; vapply replaces the 1:length() index loop and
# guarantees a numeric vector of the same length as tau_list
BS_tau <- vapply(
  tau_list,
  function(tau) BSfun(tau = tau, test$time, test$status, St, t, Gtc, tc),
  numeric(1)
)
plot(tau_list, BS_tau, type = 'l', lwd = 2)

# Full Cox model: all predictors in pred_list, no selection
obj_full <- coxph(Surv(train$time, train$status) ~ as.matrix(train[, pred_list]))
# Get the predicted survival rates for the test set
pred_surv_full <- predsurv_cox(obj_full, Z = as.matrix(test[, pred_list]))
St_full <- pred_surv_full$St
t_full <- pred_surv_full$t
# Compute the Brier score at tau=12 to 60 months
# under Cox-full
BS_tau_full <- vapply(
  tau_list,
  function(tau) {
    BSfun(tau = tau, test$time, test$status, St_full, t_full, Gtc, tc)
  },
  numeric(1)
)
##############################
# Survival trees #
##############################
library(rpart)
library(rpart.plot)
### Grow a survival tree and record its cross-validation error ###
set.seed(12345)
# Settings passed to rpart.control:
#   xval = 10     -> 10-fold cross-validation
#   minbucket = 2 -> minimum terminal node size 2
#   cp = 0        -> complexity parameter set to zero
obj <- rpart(
  Surv(time, status) ~ hormone + meno + size + grade + nodes +
    prog + estrg + age,
  data = train,
  control = rpart.control(xval = 10, minbucket = 2, cp = 0)
)
printcp(obj)
# CP nsplit rel.error xerror xstd
# 1 0.07556835 0 1.00000 1.00411 0.046231
# 2 0.03720019 1 0.92443 0.96817 0.047281
# 3 0.02661914 2 0.88723 0.95124 0.046567
# 4 0.01716925 3 0.86061 0.92745 0.046606
# 5 0.01398306 4 0.84344 0.92976 0.047514
# 6 0.01394869 5 0.82946 0.93941 0.048404
# 7 0.01055028 9 0.77120 0.97722 0.052133
# 8 0.01053135 10 0.76065 1.00140 0.054295
# summary(obj)

# Cross-validation table of the grown tree
cptable <- obj$cptable
# Complexity parameter values (first column of the table)
CP <- cptable[, "CP"]
# Optimal value: the one with the smallest cross-validated error
# (fourth column of the table, "xerror")
cp.opt <- CP[which.min(cptable[, "xerror"])]
# Prune the tree back to the optimal complexity
fit <- prune(obj, cp = cp.opt)
par(mfrow = c(1, 2))
# plot the pruned tree structure
rpart.plot(fit)
# plot the KM curves for the terminal nodes
# (fit$where gives the terminal node of every training subject)
km <- survfit(Surv(time, status) ~ fit$where, train)
# One line type per terminal node, derived from the pruned tree instead of
# the hard-coded 1:4 so that curves and legend stay aligned for any tree size
node_ids <- sort(unique(fit$where))
node_lty <- seq_along(node_ids)
# NOTE(review): the Brier-score code treats `time` as months (tau = 12 to 60
# months), so the "Years" axis label looks inconsistent -- confirm the unit.
# The original passed cex.lab twice; the second is taken to be cex.axis, as
# in the final Brier-score plot.
plot(km, lty = node_lty, mark.time = FALSE,
     xlab = "Years", ylab = "Progression", lwd = 2,
     cex.lab = 1.2, cex.axis = 1.2)
legend("bottomleft", paste('Node', node_ids), lty = node_lty, lwd = 2, cex = 1.2)
# Extract the KM estimates of the outcome within each terminal node
tmp <- summary(km)
# The strata labels end in "=<node id>"; strip everything up to "=" and
# keep the node id as an integer
tmp.strata <- as.integer(sub(".*=", "", tmp$strata))
tmp.t <- tmp$time
tmp.surv <- tmp$surv
# Terminal node ids and their number
TN <- unique(tmp.strata)
N <- length(TN)
# Common time grid: all distinct times reported in any node
t <- sort(unique(tmp.t))
m <- length(t)
# fitted_surv[j, k] is the KM survival rate of node TN[k] at time t[j]:
# each node's step function (starting from rate 1 at time 0) is evaluated
# on the common grid
fitted_surv <- matrix(NA, m, N)
for (k in seq_len(N)) {
  in_node <- tmp.strata == TN[k]
  tk <- c(0, tmp.t[in_node])
  survk <- c(1, tmp.surv[in_node])
  # number of node-k time points at or before each grid time
  n_before <- vapply(t, function(tj) sum(tk <= tj), integer(1))
  fitted_surv[, k] <- survk[n_before]
}
# Assign every test subject to a terminal node of the pruned tree
library(treeClust)
test_term <- rpart.predict.leaves(fit, test)
n <- length(test_term)
# Row i of St_tree is the fitted survival curve (over the m grid times) of
# the terminal node that test subject i falls into; rows whose node is not
# among TN stay NA
St_tree <- matrix(NA, n, m)
node_col <- match(test_term, TN)
matched <- !is.na(node_col)
# base::t is spelled out because a numeric vector named `t` (the time grid)
# also exists in the workspace
St_tree[matched, ] <- base::t(fitted_surv[, node_col[matched], drop = FALSE])
## Get the KM estimates for censoring distribution on the test set
## (censoring is the event of interest here, hence 1 - status)
G_obj <- summary(survfit(Surv(time, 1 - status) ~ 1, data = test))
tc <- c(0, G_obj$time)
Gtc <- c(1, G_obj$surv)
# Compute the Brier score at tau=12 to 60 months
# under the pruned survival tree
tau_list <- 12:60
# One BSfun call per tau; vapply replaces the 1:length() index loop and
# guarantees a numeric vector of the same length as tau_list
BS_tau_tree <- vapply(
  tau_list,
  function(tau) BSfun(tau = tau, test$time, test$status, St_tree, t, Gtc, tc),
  numeric(1)
)
# Plot the Brier score curves for Cox-lasso, Cox-full, and survival tree
# (tau_list is divided by 12 so that the x-axis is in years)
par(mfrow = c(1, 1))
plot(tau_list / 12, BS_tau_tree, type = 'l', lwd = 2, col = "red",
     cex.axis = 1.2, cex.lab = 1.2, xlab = "t (years)", ylab = "BS(t)",
     ylim = c(0, 0.25))
lines(tau_list / 12, BS_tau, lty = 1, lwd = 2, col = "blue")
lines(tau_list / 12, BS_tau_full, lty = 1, lwd = 2, col = "black")
# The legend is placed by keyword, so no y coordinate is needed; the stray
# positional `1` of the original call was ignored and has been dropped
legend("bottomright",
       legend = c("Survival Tree", "Cox-lasso", "Cox-full"),
       col = c("red", "blue", "black"), lwd = 2, cex = 1.2)
```