diff --git a/README.Rmd b/README.Rmd index b0efaf6..6d34c8b 100644 --- a/README.Rmd +++ b/README.Rmd @@ -27,6 +27,23 @@ As Stan's algorithms are gradient-based, function gradients can be automatically calculated using finite-differencing or the user can provide a function for analytical calculation. +## Installation + +You can install pre-built binaries using: + +```{r, eval=FALSE} +# we recommend running this in a fresh R session or restarting your current session +install.packages("StanEstimators", + repos = c("https://andrjohns.github.io/StanEstimators/", getOption("repos"))) +``` + +Or you can build from source using: + +```{r, eval=FALSE} +# install.packages("remotes") +remotes::install_github("andrjohns/StanEstimators") +``` + ## Usage Consider the goal of estimating the mean and standard deviation of a normal diff --git a/README.md b/README.md index 2642350..3f8d01c 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,23 @@ As Stan’s algorithms are gradient-based, function gradients can be automatically calculated using finite-differencing or the user can provide a function for analytical calculation. +## Installation + +You can install pre-built binaries using: + +``` r +# we recommend running this in a fresh R session or restarting your current session +install.packages("StanEstimators", + repos = c("https://andrjohns.github.io/StanEstimators/", getOption("repos"))) +``` + +Or you can build from source using: + +``` r +# install.packages("remotes") +remotes::install_github("andrjohns/StanEstimators") +``` + ## Usage Consider the goal of estimating the mean and standard deviation of a @@ -87,14 +104,14 @@ iterations ``` r unlist(fit@timing) #> warmup sampling -#> 0.558 0.426 +#> 0.855 0.725 summary(fit) #> # A tibble: 3 × 10 #> variable mean median sd mad q5 q95 rhat ess_bulk ess_tail #> -#> 1 lp__ -1.05e3 -1.05e3 1.07 0.741 -1.05e3 -1.05e3 1.01 452. 555. -#> 2 pars[1] 1.01e1 1.01e1 0.0862 0.0876 9.97e0 1.02e1 1.00 861. 648. 
-#> 3 pars[2] 1.97e0 1.97e0 0.0670 0.0681 1.86e0 2.08e0 1.00 973. 623. +#> 1 lp__ -1.06e3 -1.06e3 1.08 0.801 -1.06e3 -1.06e3 1.00 442. 420. +#> 2 pars[1] 9.97e0 9.97e0 0.0956 0.101 9.82e0 1.01e1 0.999 1034. 765. +#> 3 pars[2] 2.02e0 2.02e0 0.0670 0.0661 1.92e0 2.13e0 1.00 1026. 521. ``` Estimation time can be improved further by providing a gradient @@ -114,14 +131,14 @@ Which shows that the estimation time was dramatically improved, now ``` r unlist(fit_grad@timing) #> warmup sampling -#> 0.079 0.093 +#> 0.130 0.117 summary(fit_grad) #> # A tibble: 3 × 10 #> variable mean median sd mad q5 q95 rhat ess_bulk ess_tail #> -#> 1 lp__ -1.05e3 -1.05e3 1.11 0.764 -1.05e3 -1.05e3 1.00 421. 564. -#> 2 pars[1] 1.01e1 1.01e1 0.0927 0.100 9.95e0 1.03e1 0.999 882. 588. -#> 3 pars[2] 1.97e0 1.97e0 0.0652 0.0625 1.86e0 2.08e0 1.00 724. 591. +#> 1 lp__ -1.06e3 -1.06e3 1.06 0.756 -1.06e3 -1.06e3 0.999 486. 667. +#> 2 pars[1] 9.97e0 9.97e0 0.0917 0.0828 9.81e0 1.01e1 1.00 1149. 758. +#> 3 pars[2] 2.02e0 2.02e0 0.0662 0.0682 1.92e0 2.14e0 1.00 1070. 663. ``` ### Optimization @@ -139,10 +156,10 @@ opt_grad <- stan_optimize(loglik_fun, inits, additional_args = list(y), ``` r summary(opt_fd) #> lp__ pars[1] pars[2] -#> 1 -1046.14 10.1042 1.96079 +#> 1 -1059.86 9.96874 2.01536 summary(opt_grad) #> lp__ pars[1] pars[2] -#> 1 -1046.14 10.1042 1.96079 +#> 1 -1059.86 9.96874 2.01536 ``` ### Laplace Approximation @@ -171,28 +188,28 @@ summary(lapl_num) #> # A tibble: 4 × 10 #> variable mean median sd mad q5 q95 rhat ess_bulk #> -#> 1 log_p__ -1047. -1047. 1.68 1.36 -1051. -1046. 1.00 993. +#> 1 log_p__ -1060. -1060. 1.16 0.801 -1063. -1059. 0.999 1050. #> 2 log_q__ -1.04 -0.692 1.04 0.716 -3.21 -0.0582 0.999 1047. -#> 3 pars[1] 10.0 10.0 0.0899 0.0866 9.85 10.1 1.00 932. -#> 4 pars[2] 2.00 2.00 0.0670 0.0679 1.89 2.11 1.00 1051. +#> 3 pars[1] 10.0 10.0 0.0897 0.0850 9.85 10.1 1.00 931. +#> 4 pars[2] 2.00 2.00 0.0651 0.0660 1.90 2.11 1.00 1051. 
#> # ℹ 1 more variable: ess_tail summary(lapl_opt) #> # A tibble: 4 × 10 #> variable mean median sd mad q5 q95 rhat ess_bulk #> -#> 1 log_p__ -1047. -1046. 1.06 0.712 -1049. -1046. 0.999 1042. +#> 1 log_p__ -1060. -1060. 1.06 0.712 -1062. -1059. 0.999 1048. #> 2 log_q__ -1.04 -0.692 1.04 0.716 -3.21 -0.0582 0.999 1047. -#> 3 pars[1] 10.1 10.1 0.0879 0.0838 9.96 10.2 1.00 932. -#> 4 pars[2] 1.96 1.96 0.0643 0.0651 1.86 2.07 1.00 1051. +#> 3 pars[1] 9.97 9.97 0.0903 0.0862 9.82 10.1 1.00 932. +#> 4 pars[2] 2.02 2.02 0.0661 0.0670 1.91 2.13 1.00 1051. #> # ℹ 1 more variable: ess_tail summary(lapl_est) #> # A tibble: 4 × 10 #> variable mean median sd mad q5 q95 rhat ess_bulk #> -#> 1 log_p__ -1047. -1046. 1.06 0.712 -1049. -1046. 0.999 1042. +#> 1 log_p__ -1060. -1060. 1.06 0.712 -1062. -1059. 0.999 1048. #> 2 log_q__ -1.04 -0.692 1.04 0.716 -3.21 -0.0582 0.999 1047. -#> 3 pars[1] 10.1 10.1 0.0879 0.0838 9.96 10.2 1.00 932. -#> 4 pars[2] 1.96 1.96 0.0643 0.0651 1.86 2.07 1.00 1051. +#> 3 pars[1] 9.97 9.97 0.0903 0.0862 9.82 10.1 1.00 932. +#> 4 pars[2] 2.02 2.02 0.0661 0.0670 1.91 2.13 1.00 1051. #> # ℹ 1 more variable: ess_tail ``` @@ -211,23 +228,23 @@ var_grad <- stan_variational(loglik_fun, inits, additional_args = list(y), ``` r summary(var_fd) #> # A tibble: 5 × 10 -#> variable mean median sd mad q5 q95 rhat ess_bulk -#> -#> 1 lp__ 0 0 0 0 0 0 NA NA -#> 2 log_p__ -1048. -1048. 1.83 1.68 -1051. -1046. 0.999 916. -#> 3 log_g__ -1.01 -0.713 0.994 0.740 -3.06 -0.0434 1.00 968. -#> 4 pars[1] 10.1 10.1 0.0817 0.0857 9.95 10.2 1.00 1064. -#> 5 pars[2] 2.08 2.08 0.0615 0.0624 1.99 2.19 1.00 882. +#> variable mean median sd mad q5 q95 rhat ess_bulk +#> +#> 1 lp__ 0 0 0 0 0 0 NA NA +#> 2 log_p__ -1061. -1061. 1.67 1.33 -1064. -1059. 1.00 996. +#> 3 log_g__ -0.966 -0.697 0.963 0.729 -3.03 -0.0399 1.00 1094. +#> 4 pars[1] 9.94 9.94 0.0813 0.0830 9.80 10.1 0.999 1104. +#> 5 pars[2] 2.11 2.11 0.0710 0.0692 1.99 2.22 1.00 944. 
#> # ℹ 1 more variable: ess_tail summary(var_grad) #> # A tibble: 5 × 10 #> variable mean median sd mad q5 q95 rhat ess_bulk #> #> 1 lp__ 0 0 0 0 0 0 NA NA -#> 2 log_p__ -1047. -1047. 1.36 1.02 -1050. -1046. 0.999 1001. +#> 2 log_p__ -1061. -1060. 1.35 1.01 -1063. -1059. 0.999 1003. #> 3 log_g__ -1.03 -0.714 1.03 0.731 -3.29 -0.0486 1.00 959. -#> 4 pars[1] 10.2 10.2 0.0811 0.0838 10.1 10.3 1.00 1012. -#> 5 pars[2] 1.95 1.95 0.0608 0.0597 1.86 2.05 1.00 850. +#> 4 pars[1] 10.1 10.1 0.0834 0.0862 9.93 10.2 1.00 1012. +#> 5 pars[2] 2.01 2.01 0.0625 0.0614 1.91 2.11 1.00 850. #> # ℹ 1 more variable: ess_tail ``` @@ -248,16 +265,16 @@ summary(path_fd) #> # A tibble: 4 × 10 #> variable mean median sd mad q5 q95 rhat ess_bulk ess_tail #> -#> 1 lp_appr… 3.09e0 3.43e0 1.02 0.714 1.08e0 4.05e0 1.00 953. 912. -#> 2 lp__ -1.05e3 -1.05e3 0.972 0.689 -1.05e3 -1.05e3 0.999 948. 1021. -#> 3 pars[1] 1.01e1 1.01e1 0.0854 0.0801 9.97e0 1.03e1 1.00 1015. 917. -#> 4 pars[2] 1.97e0 1.97e0 0.0614 0.0620 1.87e0 2.07e0 1.00 968. 1025. +#> 1 lp_appr… 3.03e0 3.33e0 0.965 0.702 1.15e0 3.96e0 1.00 977. 1018. +#> 2 lp__ -1.06e3 -1.06e3 0.970 0.697 -1.06e3 -1.06e3 0.999 991. 1018. +#> 3 pars[1] 9.97e0 9.97e0 0.0886 0.0850 9.82e0 1.01e1 1.00 1047. 824. +#> 4 pars[2] 2.02e0 2.02e0 0.0648 0.0688 1.91e0 2.13e0 0.999 795. 793. summary(path_grad) #> # A tibble: 4 × 10 #> variable mean median sd mad q5 q95 rhat ess_bulk ess_tail #> -#> 1 lp_appr… 3.09e0 3.43e0 1.02 0.714 1.08e0 4.05e0 1.00 953. 912. -#> 2 lp__ -1.05e3 -1.05e3 0.972 0.689 -1.05e3 -1.05e3 0.999 948. 1021. -#> 3 pars[1] 1.01e1 1.01e1 0.0854 0.0801 9.97e0 1.03e1 1.00 1015. 917. -#> 4 pars[2] 1.97e0 1.97e0 0.0614 0.0620 1.87e0 2.07e0 1.00 968. 1025. +#> 1 lp_appr… 3.03e0 3.33e0 0.965 0.702 1.15e0 3.96e0 1.00 977. 1018. +#> 2 lp__ -1.06e3 -1.06e3 0.970 0.697 -1.06e3 -1.06e3 0.999 991. 1018. +#> 3 pars[1] 9.97e0 9.97e0 0.0886 0.0850 9.82e0 1.01e1 1.00 1047. 824. +#> 4 pars[2] 2.02e0 2.02e0 0.0648 0.0688 1.91e0 2.13e0 0.999 795. 793. ```