diff --git a/R/api.R b/R/api.R
index 2311301..bd5f436 100644
--- a/R/api.R
+++ b/R/api.R
@@ -91,7 +91,10 @@ target_get_trace <- function(name,
                              req,
                              filter = NULL,
                              disaggregate = NULL,
-                             scale = "natural") {
+                             scale = "natural",
+                             method = "auto",
+                             span = 0.75,
+                             k = 10) {
   logger::log_info(paste("Requesting data from", name,
                          "with biomarker", biomarker))
   dataset <- read_dataset(req, name, scale)
@@ -112,7 +115,11 @@ target_get_trace <- function(name,
     groups <- split(dat, eval(parse(text = paste("~", disaggregate))))
     nms <- names(groups)
     return(lapply(seq_along(groups), function(i) {
-      model <- with_warnings(model_out(groups[[i]], xcol))
+      model <- with_warnings(model_out(groups[[i]],
+                                       xcol = xcol,
+                                       method = method,
+                                       span = span,
+                                       k = k))
       list(name = jsonlite::unbox(nms[[i]]),
            model = model$output,
            raw = data_out(groups[[i]], xcol),
@@ -120,7 +127,11 @@ target_get_trace <- function(name,
     }))
   } else {
     logger::log_info("Returning single trace")
-    model <- with_warnings(model_out(dat, xcol))
+    model <- with_warnings(model_out(dat,
+                                     xcol = xcol,
+                                     method = method,
+                                     span = span,
+                                     k = k))
     nm <- ifelse(is.null(filter), "all", filter)
     return(list(list(name = jsonlite::unbox(nm),
                      model = model$output,
@@ -149,16 +160,26 @@ read_dataset <- function(req, name, scale) {
   list(data = dat, xcol = xcol)
 }
 
-model_out <- function(dat, xcol) {
+# Fit a smoothed trend of `value` against the column named by `xcol`.
+#
+# method: "gam" fits mgcv::gam with a "cs" smooth of basis dimension `k`;
+#   "auto" does the same only when `dat` has more than 1000 rows; any
+#   other value (and "auto" on smaller data) fits stats::loess with `span`.
+# The defaults mirror those of target_get_trace, so the function can be
+# called directly with only `dat` and `xcol`.
+# Returns list(x = list(), y = list()) early when `dat` has no rows.
+model_out <- function(dat, xcol, method = "auto", span = 0.75, k = 10) {
   n <- nrow(dat)
   if (n == 0) {
     return(list(x = list(), y = list()))
   }
-  if (n > 1000) {
-    m <- mgcv::gam(value ~ s(eval(parse(text = xcol)), bs = "cs"),
+  if ((n > 1000 && method == "auto") || method == "gam") {
+    fmla <- sprintf("value ~ s(%s, bs = 'cs', k = %f)", xcol, k)
+    m <- mgcv::gam(stats::as.formula(fmla),
                    data = dat, method = "REML")
   } else {
-    m <- stats::loess(value ~ eval(parse(text = xcol)), data = dat, span = 0.75)
+    fmla <- sprintf("value ~ %s", xcol)
+    m <- stats::loess(stats::as.formula(fmla), data = dat, span = span)
   }
   range <- range(dat[, xcol], na.rm = TRUE)
   xseq <- range[1]:range[2]
diff --git a/R/router.R b/R/router.R index a566270..0e9f86b 100644 --- a/R/router.R +++ b/R/router.R @@ -78,7 +78,10 @@ get_trace <- function() { target_get_trace, porcelain::porcelain_input_query(disaggregate = "string", filter = "string", - scale = "string"), + scale = "string", + method = "string", + span = "numeric", + k = "numeric"), returning = porcelain::porcelain_returning_json("DataSeries")) } diff --git a/R/utils.R b/R/utils.R index 8d1417a..f0b6a3c 100644 --- a/R/utils.R +++ b/R/utils.R @@ -6,7 +6,11 @@ with_warnings <- function(expr) { invokeRestart("muffleWarning") } - val <- withCallingHandlers(expr, warning = w_handler) + e_handler <- function(e) { + porcelain::porcelain_stop(jsonlite::unbox(conditionMessage(e))) + } + + val <- withCallingHandlers(expr, warning = w_handler, error = e_handler) list(output = val, warnings = my_warnings) } diff --git a/tests/testthat/test-model.R b/tests/testthat/test-model.R new file mode 100644 index 0000000..61ed0b6 --- /dev/null +++ b/tests/testthat/test-model.R @@ -0,0 +1,68 @@ +test_that("model is gam if specified", { + dat <- data.frame(day = 1:100, value = rnorm(100)) + res <- model_out(dat, xcol = "day", method = "gam") + + m <- mgcv::gam(value ~ s(day, bs = "cs"), + data = dat, method = "REML") + xdf <- tibble::tibble(day = 1:100) + expected <- stats::predict(m, xdf) + + expect_true(all(res$y == expected)) +}) + +test_that("model is loess if specified", { + dat <- data.frame(day = 1:2000, value = rnorm(2000)) + res <- model_out(dat, xcol = "day", method = "loess") + + m <- stats::loess(value ~ day, data = dat, span = 0.75) + xdf <- tibble::tibble(day = 1:2000) + expected <- stats::predict(m, xdf) + + expect_true(all(res$y == expected)) +}) + +test_that("model is loess if not specified and n <= 1000", { + dat <- data.frame(day = 1:1000, value = rnorm(1000)) + res <- model_out(dat, xcol = "day") + + m <- stats::loess(value ~ day, data = dat, span = 0.75) + xdf <- tibble::tibble(day = 1:1000) + expected <- 
stats::predict(m, xdf) + + expect_true(all(res$y == expected)) +}) + +test_that("model is gam if not specified and n > 1000", { + dat <- data.frame(day = 1:1001, value = rnorm(1001)) + res <- model_out(dat, xcol = "day") + + m <- mgcv::gam(value ~ s(day, bs = "cs"), + data = dat, method = "REML") + xdf <- tibble::tibble(day = 1:1001) + expected <- stats::predict(m, xdf) + + expect_true(all(res$y == expected)) +}) + +test_that("model uses gam options", { + dat <- data.frame(day = 1:1001, value = rnorm(1001)) + res <- model_out(dat, xcol = "day", k = 5) + + m <- mgcv::gam(value ~ s(day, bs = "cs", k = 5), + data = dat, method = "REML") + xdf <- tibble::tibble(day = 1:1001) + expected <- stats::predict(m, xdf) + + expect_true(all(res$y == expected)) +}) + +test_that("model uses loess options", { + dat <- data.frame(day = 1:100, value = rnorm(100)) + res <- model_out(dat, xcol = "day", span = 0.5) + + m <- stats::loess(value ~ day, data = dat, span = 0.5) + xdf <- tibble::tibble(day = 1:100) + expected <- stats::predict(m, xdf) + + expect_true(all(res$y == expected)) +}) diff --git a/tests/testthat/test-read.R b/tests/testthat/test-read.R index 4862a5e..c0b0d27 100644 --- a/tests/testthat/test-read.R +++ b/tests/testthat/test-read.R @@ -227,3 +227,65 @@ test_that("can get log2 data", { )) }) +test_that("can use loess model options", { + dat <- data.frame(biomarker = "ab", + value = 1:5, + day = 1:5) + router <- build_routes(cookie_key) + local_add_dataset(dat, name = "testdataset") + res <- router$call(make_req("GET", + "/dataset/testdataset/trace/ab/", + qs = "method=loess&span=0.5", + HTTP_COOKIE = cookie)) + expect_equal(res$status, 200) + body <- jsonlite::fromJSON(res$body) + data <- body$data + + suppressWarnings(m <- stats::loess(value ~ day, data = dat, span = 0.5)) + xdf <- tibble::tibble(day = 1:5) + expected <- stats::predict(m, xdf) + expect_equal(unlist(data$model[1, "y"]), + jsonlite::fromJSON( + jsonlite::toJSON(expected) # convert to/from json for 
consistent rounding + )) +}) + +test_that("can use gam model options", { + dat <- data.frame(biomarker = "ab", + value = 1:5, + day = 1:5) + router <- build_routes(cookie_key) + local_add_dataset(dat, name = "testdataset") + res <- router$call(make_req("GET", + "/dataset/testdataset/trace/ab/", + qs = "method=gam&k=2", + HTTP_COOKIE = cookie)) + expect_equal(res$status, 200) + body <- jsonlite::fromJSON(res$body) + data <- body$data + suppressWarnings(m <- mgcv::gam(value ~ s(day, bs = "cs", k = 2), + data = dat, method = "REML")) + xdf <- tibble::tibble(day = 1:5) + expected <- stats::predict(m, xdf) + expect_equal(unlist(data$model[1, "y"]), + jsonlite::fromJSON( + jsonlite::toJSON(expected) # convert to/from json for consistent rounding + )) +}) + +test_that("error running the model results in a 400", { + dat <- data.frame(biomarker = "ab", + value = 1:5, + day = 1:5) + router <- build_routes(cookie_key) + local_add_dataset(dat, name = "testdataset") + res <- router$call(make_req("GET", + "/dataset/testdataset/trace/ab/", + qs = "method=gam&k=10", + HTTP_COOKIE = cookie)) + expect_equal(res$status, 400) + validate_failure_schema(res$body) + body <- jsonlite::fromJSON(res$body) + expect_equal(body$errors[1, "detail"], + "day has insufficient unique values to support 10 knots: reduce k.") +})