Skip to content

Commit ef3fbc6

Browse files
gravestigithub-actions[bot]dependabot-preview[bot]
authored
Speed up (#119)
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: 27856297+dependabot-preview[bot]@users.noreply.github.com <27856297+dependabot-preview[bot]@users.noreply.github.com>
1 parent c9c1440 commit ef3fbc6

File tree

11 files changed

+211
-75
lines changed

11 files changed

+211
-75
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# R specific hooks: https://github.com/lorenzwalthert/precommit
44
repos:
55
- repo: https://github.com/lorenzwalthert/precommit
6-
rev: v0.4.2
6+
rev: v0.4.3
77
hooks:
88
- id: style-files
99
args: [--style_pkg=styler, --style_fun=tidyverse_style]
@@ -91,6 +91,6 @@ repos:
9191
files: '\.Rhistory|\.RData|\.Rds|\.rds$'
9292
# `exclude: <regex>` to allow committing specific files.
9393
- repo: https://github.com/igorshubovych/markdownlint-cli
94-
rev: v0.40.0
94+
rev: v0.41.0
9595
hooks:
9696
- id: markdownlint

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ importFrom(stats,dbeta)
5151
importFrom(stats,dbinom)
5252
importFrom(stats,integrate)
5353
importFrom(stats,optimize)
54+
importFrom(stats,pbeta)
5455
importFrom(stats,quantile)
5556
importFrom(stats,rbeta)
5657
importFrom(stats,rbinom)

R/dbetabinom.R

Lines changed: 50 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -124,14 +124,20 @@ dbetaMix <- function(x, par, weights, log = FALSE) {
124124
assert_numeric(weights, lower = 0, upper = 1, finite = TRUE, any.missing = FALSE)
125125
assert_true(all.equal(sum(weights), 1))
126126
assert_true(identical(length(weights), nrow(par)))
127-
ret <- sum(weights * dbeta(x, par[, 1], par[, 2]))
127+
degree <- length(weights)
128+
129+
component_densities <- matrix(
130+
dbeta(rep(x, each = degree), par[, 1], par[, 2]),
131+
nrow = degree,
132+
ncol = length(x)
133+
)
134+
ret <- as.numeric(weights %*% component_densities)
128135
if (log) {
129136
log(ret)
130137
} else {
131138
ret
132139
}
133140
}
134-
dbetaMix <- Vectorize(dbetaMix, vectorize.args = "x")
135141

136142

137143
#' Beta-Mixture CDF
@@ -155,15 +161,25 @@ dbetaMix <- Vectorize(dbetaMix, vectorize.args = "x")
155161
#' @note `q` can be a vector.
156162
#'
157163
#' @example examples/pbetaMix.R
164+
#' @importFrom stats pbeta
158165
#' @export
159166
pbetaMix <- function(q, par, weights, lower.tail = TRUE) {
160-
assert_number(q, lower = 0, upper = 1, finite = TRUE)
167+
assert_numeric(q, lower = 0, upper = 1, finite = TRUE)
161168
assert_numeric(weights, lower = 0, upper = 1, finite = TRUE)
162169
assert_matrix(par)
163170
assert_flag(lower.tail)
164-
sum(weights * pbeta(q, par[, 1], par[, 2], lower.tail = lower.tail))
171+
.pbetaMix(q = q, par = par, weights = weights, lower.tail = lower.tail)
172+
}
173+
174+
.pbetaMix <- function(q, par, weights, lower.tail) {
175+
degree <- length(weights)
176+
component_p <- matrix(
177+
pbeta(rep(q, each = degree), par[, 1], par[, 2], lower.tail = lower.tail),
178+
nrow = degree,
179+
ncol = length(q)
180+
)
181+
as.numeric(weights %*% component_p)
165182
}
166-
pbetaMix <- Vectorize(pbetaMix, vectorize.args = "q")
167183

168184

169185
#' Beta-Mixture Quantile Function
@@ -186,23 +202,33 @@ pbetaMix <- Vectorize(pbetaMix, vectorize.args = "q")
186202
#' @example examples/qbetaMix.R
187203
#' @export
188204
qbetaMix <- function(p, par, weights, lower.tail = TRUE) {
189-
f <- function(pi) {
190-
pbetaMix(q = pi, par = par, weights = weights, lower.tail = lower.tail) - p
191-
}
192-
# Note: we give the lower and upper function values here in order to avoid problems for
193-
# p = 0 or p = 1.
194-
unirootResult <- uniroot(
195-
f,
196-
lower = 0, upper = 1,
197-
f.lower = -p, f.upper = 1 - p,
198-
tol = sqrt(.Machine$double.eps) # Increase the precision over default `tol`.
199-
)
200-
if (unirootResult$iter < 0) {
201-
NA
202-
} else {
203-
assert_number(unirootResult$root)
204-
assert_true(all.equal(f(unirootResult$root), 0, tolerance = 0.001))
205-
unirootResult$root
206-
}
205+
assert_numeric(p, lower = 0, upper = 1)
206+
assert_numeric(weights, lower = 0, upper = 1, finite = TRUE)
207+
assert_matrix(par)
208+
assert_flag(lower.tail)
209+
210+
grid <- seq(0, 1, len = 31)
211+
f_grid <- .pbetaMix(grid, par, weights, lower.tail = lower.tail)
212+
213+
sapply(p, function(p) {
214+
# special cases
215+
if (p == 0) {
216+
return(0)
217+
}
218+
if (p == 1) {
219+
return(1)
220+
}
221+
222+
diff <- f_grid - p
223+
pos <- diff > 0
224+
grid_interval <- c(grid[!pos][which.max(diff[!pos])], grid[pos][which.min(diff[pos])])
225+
226+
uniroot(
227+
f = function(q) .pbetaMix(q, par, weights, lower.tail = lower.tail) - p,
228+
interval = grid_interval,
229+
f.lower = -p,
230+
f.upper = 1 - p,
231+
tol = sqrt(.Machine$double.eps)
232+
)$root
233+
})
207234
}
208-
qbetaMix <- Vectorize(qbetaMix, vectorize.args = "p")

R/postprob.R

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,18 +81,25 @@ postprob <- function(x, n, p, parE = c(1, 1), weights, betamixPost, log.p = FALS
8181
if (missing(weights)) {
8282
weights <- rep(1, nrow(parE))
8383
}
84-
betamixPost <- h_getBetamixPost(
85-
x = x,
84+
betamixPost <- lapply(
85+
x,
86+
h_getBetamixPost,
8687
n = n,
8788
par = parE,
8889
weights = weights
8990
)
91+
} else {
92+
assert_list(betamixPost)
93+
assert_names(names(betamixPost), identical.to = c("par", "weights"))
94+
betamixPost <- list(betamixPost)
9095
}
91-
assert_list(betamixPost)
92-
assert_names(names(betamixPost), identical.to = c("par", "weights"))
93-
ret <- with(
96+
97+
ret <- vapply(
9498
betamixPost,
95-
pbetaMix(q = p, par = par, weights = weights, lower.tail = FALSE)
99+
FUN = function(bmp) {
100+
.pbetaMix(q = p, par = bmp$par, weights = bmp$weights, lower.tail = FALSE)
101+
},
102+
FUN.VALUE = numeric(length(p))
96103
)
97104

98105
if (log.p) {
@@ -101,4 +108,3 @@ postprob <- function(x, n, p, parE = c(1, 1), weights, betamixPost, log.p = FALS
101108
ret
102109
}
103110
}
104-
postprob <- Vectorize(postprob, vectorize.args = "x")

R/postprobDist.R

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ postprobDist <- function(x,
164164
if (missing(weightsS)) {
165165
weightsS <- rep(1, nrow(parS))
166166
}
167-
assert_number(n, lower = x, finite = TRUE)
167+
assert_number(n, lower = min(x), finite = TRUE)
168168
assert_numeric(x, lower = 0, upper = n, finite = TRUE)
169169
assert_number(nS, lower = 0, finite = TRUE)
170170
assert_number(xS, lower = 0, upper = nS, finite = TRUE)
@@ -174,10 +174,10 @@ postprobDist <- function(x,
174174
assert_numeric(weightsS, lower = 0, finite = TRUE)
175175
assert_numeric(parE, lower = 0, finite = TRUE)
176176
assert_numeric(parS, lower = 0, finite = TRUE)
177-
activeBetamixPost <- h_getBetamixPost(x = x, n = n, par = parE, weights = weights)
177+
178+
activeBetamixPost <- lapply(x, function(x) h_getBetamixPost(x = x, n = n, par = parE, weights = weights))
179+
178180
controlBetamixPost <- h_getBetamixPost(x = xS, n = nS, par = parS, weights = weightsS)
179-
assert_names(names(activeBetamixPost), identical.to = c("par", "weights"))
180-
assert_names(names(controlBetamixPost), identical.to = c("par", "weights"))
181181
if (relativeDelta) {
182182
epsilon <- .Machine$double.xmin
183183
integrand <- h_integrand_relDelta
@@ -186,26 +186,26 @@ postprobDist <- function(x,
186186
integrand <- h_integrand
187187
}
188188
bounds <- h_get_bounds(controlBetamixPost = controlBetamixPost)
189-
intRes <- integrate(
190-
f = integrand,
191-
lower =
192-
max(
193-
bounds[1],
194-
ifelse(relativeDelta, 0, 0 - delta)
195-
),
196-
upper =
197-
min(
198-
ifelse(relativeDelta, 1, 1 - delta),
199-
bounds[2]
200-
),
201-
delta = delta,
202-
activeBetamixPost = activeBetamixPost,
203-
controlBetamixPost = controlBetamixPost
189+
190+
integral_results <- lapply(
191+
seq_along(x),
192+
function(i, this_posterior = activeBetamixPost, input_x = x) {
193+
intRes <- integrate(
194+
f = integrand,
195+
lower = max(bounds[1], ifelse(relativeDelta, 0, 0 - delta)),
196+
upper = min(ifelse(relativeDelta, 1, 1 - delta), bounds[2]),
197+
delta = delta,
198+
activeBetamixPost = this_posterior[[i]],
199+
controlBetamixPost = controlBetamixPost
200+
)
201+
if (intRes$message == "OK") {
202+
intRes$value
203+
} else {
204+
warning("Integration failed for posterior based on x =", input_x[i], "\n", intRes$message)
205+
NA_real_
206+
}
207+
}
204208
)
205-
if (intRes$message == "OK") {
206-
intRes$value
207-
} else {
208-
stop(intRes$message)
209-
}
209+
210+
unlist(integral_results)
210211
}
211-
postprobDist <- Vectorize(postprobDist, vectorize.args = "x")

examples/postprob.R

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,13 @@ postprob(
1818
),
1919
weights = c(0.6, 0.4)
2020
)
21+
22+
postprob(
23+
x = 0:23, n = 23, p = 0.60,
24+
par =
25+
rbind(
26+
c(0.6, 0.4),
27+
c(1, 1)
28+
),
29+
weights = c(0.6, 0.4)
30+
)

man/postprob.Rd

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-dbetabinom.R

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,23 @@ test_that("pbetaMix gives the correct number result with beta-mixture", {
8484
expect_equal(result, 0.4768404, tolerance = 1e-5)
8585
})
8686

87+
test_that("pbetaMix works for edge cases", {
88+
result_ushape <- pbetaMix(
89+
q = c(0, 1),
90+
par = rbind(c(0.2, 0.4), c(3, .3)),
91+
weights = c(0.6, 0.4)
92+
)
93+
expect_equal(result_ushape, c(0, 1))
94+
95+
result_vshape <- pbetaMix(
96+
q = c(0, 1),
97+
par = rbind(c(9, 4), c(1, 1)),
98+
weights = c(0.6, 0.4)
99+
)
100+
expect_equal(result_vshape, c(0, 1))
101+
})
102+
103+
87104
test_that("The complement of pbetaMix can be derived with a different lower.tail flag", {
88105
result <- pbetaMix(
89106
q = 0.3,
@@ -184,6 +201,32 @@ test_that("dbetaMix gives the correct result as dbeta", {
184201
expect_equal(result, result2, tolerance = 1e-4)
185202
})
186203

204+
test_that("dbetaMix handles edge cases", {
205+
result_inf <- dbetaMix(
206+
x = c(0, 1), par = rbind(c(0.2, 0.4), c(1, 1)),
207+
weights = c(0.6, 0.4)
208+
)
209+
expect_equal(result_inf, c(Inf, Inf))
210+
211+
result_finite <- dbetaMix(
212+
x = c(0, 1), par = rbind(c(2, 4), c(1, 1)),
213+
weights = c(0.6, 0.4)
214+
)
215+
expect_equal(result_finite, c(0.4, 0.4))
216+
217+
result_right <- dbetaMix(
218+
x = c(0, 1), par = rbind(c(0, 4), c(1, 1)),
219+
weights = c(0.6, 0.4)
220+
)
221+
expect_equal(result_right, c(Inf, 0.4))
222+
223+
result_right <- dbetaMix(
224+
x = c(NA, 1), par = rbind(c(0, 4), c(1, 1)),
225+
weights = c(0.6, 0.4)
226+
)
227+
expect_equal(result_right, c(NA, 0.4))
228+
})
229+
187230
# h_getBetamixPost ----
188231

189232
test_that("h_getBetamixPost gives the correct beta-mixture parameters", {

tests/testthat/test-ocPredprobDist.R

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -65,23 +65,28 @@ test_that("h_decision_two_predprobDist gives correct result and list", {
6565

6666
test_that("ocPredprobDist gives correct result and list when relativeDelta = FALSE", {
6767
set.seed(1989)
68-
result <- ocPredprobDist(
69-
nnE = c(10, 20, 30),
70-
truep = 0.40,
71-
deltaE = 0.5,
72-
deltaF = 0.5,
73-
relativeDelta = FALSE,
74-
tT = 0.6,
75-
phiU = 0.80,
76-
phiFu = 0.7,
77-
parE = c(1, 1),
78-
parS = c(5, 25),
79-
weights = 1,
80-
weightsS = 1,
81-
sim = 50,
82-
nnF = c(10, 20, 30),
83-
wiggle = TRUE,
84-
decision1 = FALSE
68+
expect_warning(
69+
{
70+
result <- ocPredprobDist(
71+
nnE = c(10, 20, 30),
72+
truep = 0.40,
73+
deltaE = 0.5,
74+
deltaF = 0.5,
75+
relativeDelta = FALSE,
76+
tT = 0.6,
77+
phiU = 0.80,
78+
phiFu = 0.7,
79+
parE = c(1, 1),
80+
parS = c(5, 25),
81+
weights = 1,
82+
weightsS = 1,
83+
sim = 50,
84+
nnF = c(10, 20, 30),
85+
wiggle = TRUE,
86+
decision1 = FALSE
87+
)
88+
},
89+
"achieve convergence"
8590
)
8691
result_sum <- sum(result$oc[5:7])
8792
expect_equal(result_sum, 1)

tests/testthat/test-postprob.R

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,18 @@ test_that("postprob gives incrementally higher values with increased x", {
5353
)
5454
expect_true(is_lower < is_higher)
5555
})
56+
57+
test_that("postprob works with vector x", {
58+
result <- postprob(x = 0:23, n = 23, p = 0.60, par = c(0.6, 0.4))
59+
expected <- c(
60+
1.12066620085448e-10, 6.73786529927603e-09, 1.45879637562279e-07,
61+
1.86374536434781e-06, 1.64656040420248e-05, 0.000108838231763851,
62+
0.000564103325535708, 0.00236446983272442, 0.00819197194809839,
63+
0.0238449136766029, 0.0590640325657381, 0.125847456119664,
64+
0.232931221473374, 0.378259188739121, 0.54495891589689,
65+
0.705949748288983, 0.835980805221058, 0.922929283049132,
66+
0.970355725500809, 0.991009176245894, 0.997963909660055,
67+
0.999685712592687, 0.999972679748126, 0.99999934483779
68+
)
69+
expect_equal(result, expected, tolerance = 1e-5)
70+
})

0 commit comments

Comments
 (0)