From 6a580ef46bf9b1bb2019ba68a445f3583934f0fe Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Sun, 7 Jan 2024 15:22:45 +0200 Subject: [PATCH] Add initial vignette --- .gitignore | 1 + DESCRIPTION | 9 +- R/model_methods.R | 5 +- vignettes/.gitignore | 2 + vignettes/Getting-Started.Rmd | 158 ++++++++++++++++++++++++++++++++++ 5 files changed, 170 insertions(+), 5 deletions(-) create mode 100644 vignettes/.gitignore create mode 100644 vignettes/Getting-Started.Rmd diff --git a/.gitignore b/.gitignore index 34c0754..a87f2c5 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ **.o **.so inst/lib +inst/doc diff --git a/DESCRIPTION b/DESCRIPTION index 58c692e..44e1016 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -36,7 +36,10 @@ LinkingTo: RcppParallel, rapidjsonr Suggests: - testthat (>= 3.0.0), - callr, - future + testthat (>= 3.0.0), + callr, + future, + knitr, + rmarkdown Config/testthat/edition: 3 +VignetteBuilder: knitr diff --git a/R/model_methods.R b/R/model_methods.R index c3e9848..78781f0 100644 --- a/R/model_methods.R +++ b/R/model_methods.R @@ -183,8 +183,9 @@ loo.StanBase <- function(x, pointwise_ll_fun, additional_args = list(), do.call(pointwise_ll_fun, c(list(curr_draw[par_cols]), additional_args)) })) reff <- loo::relative_eff(exp(log_lik_draws), chain_id = x@draws$.chain) - suppressWarnings(loo_res <- loo::loo.matrix(log_lik_draws, r_eff = reff, ...)) + if (moment_match) { + suppressWarnings(loo_res <- loo::loo.matrix(log_lik_draws, r_eff = reff, ...)) log_lik_i <- function(x, i, parameter_name = "log_lik", ...) { log_lik_draws[, i] } @@ -206,6 +207,6 @@ loo.StanBase <- function(x, pointwise_ll_fun, additional_args = list(), ... ) } else { - loo_res + loo::loo.matrix(log_lik_draws, r_eff = reff, ...) } } diff --git a/vignettes/.gitignore b/vignettes/.gitignore new file mode 100644 index 0000000..097b241 --- /dev/null +++ b/vignettes/.gitignore @@ -0,0 +1,2 @@ +*.html +*.R diff --git a/vignettes/Getting-Started.Rmd b/vignettes/Getting-Started.Rmd new file mode 100644 index 0000000..3d50e98 --- /dev/null +++ b/vignettes/Getting-Started.Rmd @@ -0,0 +1,158 @@ +--- +title: "Getting Started with StanEstimators" +output: + html_vignette: + toc: yes +author: "Andrew Johnson" +date: "`r Sys.Date()`" +vignette: > + %\VignetteIndexEntry{Getting Started with StanEstimators} + %\VignetteEngine{knitr::rmarkdown} + %\VignetteEncoding{UTF-8} +--- + +```{r, include = FALSE} +knitr::opts_chunk$set( + collapse = TRUE, + comment = "#>" +) +``` + +```{r setup} +library(StanEstimators) +``` + +# Introduction + +This vignette provides a basic introduction to the features and capabilities of the `StanEstimators` package. `StanEstimators` provides an interface for estimating R functions using the various algorithms and methods implemented by [Stan](https://mc-stan.org/). + +The `StanEstimators` package supports all algorithms implemented by Stan. The available methods, and their corresponding functions, are: + +Stan Method | `StanEstimators` Function +----------- | ------------------------- +MCMC Sampling | `stan_sample` +Maximum Likelihood Estimation | `stan_optimize` +Variational Inference | `stan_variational` +Pathfinder | `stan_pathfinder` +Laplace Approximation | `stan_laplace` + +## Motivations + +While Stan is powerful, it can have a high barrier to entry for new users - as they need to translate their existing models/functions into the Stan language. `StanEstimators` provides a simple interface for users to estimate their R functions using Stan's algorithms without needing to learn the Stan language. Additionally, it also allows for users to 'sanity-check' their Stan code by comparing the results of their Stan code to the results of their original R function. Finally, it can be difficult to interface Stan with existing R packages and functions - as this requires bespoke Stan models for the problem at hand, something that may be too great a time investment for many users. + +`StanEstimators` aims to address these issues. + +## Estimating a Model + +As an example, we will use the popular 'Eight Schools' example from the Stan documentation. This example is used to demonstrate the use of hierarchical models in Stan. The model is defined as: + +$$ +y_j &\sim N(\theta_j, \sigma_j), \quad j=1,\ldots,8 \\ +\theta_j &\sim \N(\mu, \tau), \quad j=1,\ldots,8 +$$ + +With data: + +School | Estimate ($y_j$) | Standard Error ($\sigma_j$) +------ | -------- | -------------- +A | 28 | 15 +B | 8 | 10 +C | -3 | 16 +D | 7 | 11 +E | -1 | 9 +F | 1 | 11 +G | 18 | 10 +H | 12 | 18 + +```{r} +y <- c(28, 8, -3, 7, -1, 1, 18, 12) +sigma <- c(15, 10, 16, 11, 9, 11, 10, 18) +``` + +### Specifying the Function + +To specify this as a function compatible with `StanEstimators`, we need to define a function that takes in a vector of parameters as the first argument and returns a single value (generally the joint log-likelihood): + +```{r} +eight_schools_lpdf <- function(v, y, sigma) { + mu <- v[1] + tau <- v[2] + eta <- v[3:10] + + # Use the non-centered parameterisation for eta + # https://mc-stan.org/docs/stan-users-guide/reparameterization.html + theta <- mu + tau * eta + + sum(dnorm(eta, mean = 0, sd = 1, log = TRUE)) + + sum(dnorm(y, mean = theta, sd = sigma, log = TRUE)) +} +``` + +Note that any additional data required by the function are passed as additional arguments. In this case, we need to pass the data for $y$ and $\sigma$. Alternatively, the function can assume that these data will be available in the global environment, rather than passed as arguments. + +### Estimating the Function + +To estimate our model, we simply pass the function to the relevant `StanEstimators` function. For example, to estimate the function using MCMC sampling, we use the `stan_sample` function (which uses the No-U-Turn Sampler by default). + +#### Parameter Bounds + +Because we are estimating a standard deviation in our model ($\tau$), we need to ensure that it is positive. We can do this by specifying a lower bound of 0 for $\tau$. This is done by passing a vector of lower bounds to the `lower` argument, with the corresponding elements of the vector matching the order of the parameters in the function. Noting that $\tau$ is the second parameter in the function, and we do not want to specify a lower bound for any other parameters, we can specify the lower bounds as: + +```{r} +lower <- c(-Inf, 0, rep(-Inf, 8)) +``` + +#### Initial Values + +We also need to specify initial values for the parameters. These are passed as a vector to the `init` argument, with the corresponding elements of the vector matching the order of the parameters in the function. This is currently required for determining the total number of parameters to estimate, and will be removed in a future version of `StanEstimators`. For now we will just specify all initial values as 1: + +```{r} +init <- rep(1, 10) +``` + +#### Running the Model + +We can now pass these arguments to the `stan_sample` function to estimate our model. We will use the default number of chains (4) and the default number of warmup iterations (1000) and sampling iterations (1000). + +```{r, results=FALSE, message=FALSE, warning=FALSE} +fit <- stan_sample(eight_schools_lpdf, + par_inits = init, + additional_args = list(y = y, sigma = sigma), + lower = lower) +``` + +### Inspecting the Results + +The estimates are stored in the `fit` object using the [`posterior::draws` format](https://cran.r-project.org/web/packages/posterior/vignettes/posterior.html). We can inspect the estimates using the `summary` function (which calls `posterior::summarise_draws`): + +```{r} +summary(fit) +``` + +## Model Checking and Comparison - Leave-One-Out Cross-Validation (LOO-CV) + +`StanEstimators` also supports the use of the [loo](https://mc-stan.org/loo/articles/loo2-example.html) package for model checking and comparison. To use this, we need to specify a function which returns the pointwise log-likelihood for each observation in the data - as our original function returns the sum of all log-likelihoods. + +For our model, we can define this function as: + +```{r} +eight_schools_pointwise <- function(v, y, sigma) { + mu <- v[1] + tau <- v[2] + eta <- v[3:10] + + # Use the non-centered parameterisation for eta + # https://mc-stan.org/docs/stan-users-guide/reparameterization.html + theta <- mu + tau * eta + + # Only the log-likelihood for the outcome variable + dnorm(y, mean = theta, sd = sigma, log = TRUE) +} +``` + +This can then be used with the `loo` function to calculate the LOO-CV estimate: + +```{r} +loo(fit, pointwise_ll_fun = eight_schools_pointwise, + additional_args = list(y = y, sigma = sigma)) +```