Skip to content

Commit

Permalink
Further finite-diff optims
Browse files Browse the repository at this point in the history
  • Loading branch information
andrjohns committed Dec 9, 2023
1 parent 1a3cb22 commit 5ca656e
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 76 deletions.
8 changes: 5 additions & 3 deletions R/util.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,13 @@ prepare_inputs <- function(fn, par_inits, extra_args_list, grad_fun, lower, uppe
}
bounds_types <- sapply(seq_len(length(par_inits)), function(i) {
if (lower[i] != -Inf && upper[i] != Inf) {
2
} else if (lower[i] != -Inf || upper[i] != Inf) {
3
} else if (lower[i] != -Inf) {
1
} else if (upper[i] != Inf) {
2
} else {
3
4
}
})
if (is.null(output_dir)) {
Expand Down
74 changes: 37 additions & 37 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,14 @@ iterations
``` r
unlist(fit@timing)
#> warmup sampling
#> 0.522 0.562
#> 0.558 0.426
summary(fit)
#> # A tibble: 3 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 lp__ -1.05e3 -1.05e3 0.955 0.741 -1.05e3 -1.05e3 1.00 524. 737.
#> 2 pars[1] 9.96e0 9.96e0 0.0890 0.0899 9.81e0 1.01e1 0.999 776. 756.
#> 3 pars[2] 1.98e0 1.98e0 0.0633 0.0632 1.88e0 2.09e0 1.00 1006. 721.
#> 1 lp__ -1.05e3 -1.05e3 1.07 0.741 -1.05e3 -1.05e3 1.01 452. 555.
#> 2 pars[1] 1.01e1 1.01e1 0.0862 0.0876 9.97e0 1.02e1 1.00 861. 648.
#> 3 pars[2] 1.97e0 1.97e0 0.0670 0.0681 1.86e0 2.08e0 1.00 973. 623.
```

Estimation time can be improved further by providing a gradient
Expand All @@ -114,14 +114,14 @@ Which shows that the estimation time was dramatically improved, now
``` r
unlist(fit_grad@timing)
#> warmup sampling
#> 0.078 0.075
#> 0.079 0.093
summary(fit_grad)
#> # A tibble: 3 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 lp__ -1.05e3 -1.05e3 0.956 0.689 -1.05e3 -1.05e3 1.00 473. 598.
#> 2 pars[1] 9.95e0 9.95e0 0.0853 0.0905 9.81e0 1.01e1 0.999 974. 743.
#> 3 pars[2] 1.98e0 1.98e0 0.0636 0.0598 1.88e0 2.09e0 1.00 1034. 592.
#> 1 lp__ -1.05e3 -1.05e3 1.11 0.764 -1.05e3 -1.05e3 1.00 421. 564.
#> 2 pars[1] 1.01e1 1.01e1 0.0927 0.100 9.95e0 1.03e1 0.999 882. 588.
#> 3 pars[2] 1.97e0 1.97e0 0.0652 0.0625 1.86e0 2.08e0 1.00 724. 591.
```

### Optimization
Expand All @@ -139,10 +139,10 @@ opt_grad <- stan_optimize(loglik_fun, inits, additional_args = list(y),
``` r
summary(opt_fd)
#> lp__ pars[1] pars[2]
#> 1 -1049.46 9.9546 1.9739
#> 1 -1046.14 10.1042 1.96079
summary(opt_grad)
#> lp__ pars[1] pars[2]
#> 1 -1049.46 9.9546 1.9739
#> 1 -1046.14 10.1042 1.96079
```

### Laplace Approximation
Expand Down Expand Up @@ -171,28 +171,28 @@ summary(lapl_num)
#> # A tibble: 4 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 log_p__ -1050. -1050. 1.20 0.890 -1052. -1049. 1.00 1027.
#> 1 log_p__ -1047. -1047. 1.68 1.36 -1051. -1046. 1.00 993.
#> 2 log_q__ -1.04 -0.692 1.04 0.716 -3.21 -0.0582 0.999 1047.
#> 3 pars[1] 10.0 10.0 0.0897 0.0854 9.85 10.1 1.00 930.
#> 4 pars[2] 2.00 2.00 0.0665 0.0673 1.90 2.11 1.00 1051.
#> 3 pars[1] 10.0 10.0 0.0899 0.0866 9.85 10.1 1.00 932.
#> 4 pars[2] 2.00 2.00 0.0670 0.0679 1.89 2.11 1.00 1051.
#> # ℹ 1 more variable: ess_tail <dbl>
summary(lapl_opt)
#> # A tibble: 4 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 log_p__ -1050. -1049. 1.06 0.712 -1052. -1049. 0.999 1045.
#> 1 log_p__ -1047. -1046. 1.06 0.712 -1049. -1046. 0.999 1042.
#> 2 log_q__ -1.04 -0.692 1.04 0.716 -3.21 -0.0582 0.999 1047.
#> 3 pars[1] 9.95 9.96 0.0885 0.0844 9.81 10.1 1.00 932.
#> 4 pars[2] 1.97 1.97 0.0647 0.0656 1.87 2.08 1.00 1051.
#> 3 pars[1] 10.1 10.1 0.0879 0.0838 9.96 10.2 1.00 932.
#> 4 pars[2] 1.96 1.96 0.0643 0.0651 1.86 2.07 1.00 1051.
#> # ℹ 1 more variable: ess_tail <dbl>
summary(lapl_est)
#> # A tibble: 4 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 log_p__ -1050. -1049. 1.06 0.712 -1052. -1049. 0.999 1045.
#> 1 log_p__ -1047. -1046. 1.06 0.712 -1049. -1046. 0.999 1042.
#> 2 log_q__ -1.04 -0.692 1.04 0.716 -3.21 -0.0582 0.999 1047.
#> 3 pars[1] 9.95 9.96 0.0885 0.0844 9.81 10.1 1.00 932.
#> 4 pars[2] 1.97 1.97 0.0647 0.0656 1.87 2.08 1.00 1051.
#> 3 pars[1] 10.1 10.1 0.0879 0.0838 9.96 10.2 1.00 932.
#> 4 pars[2] 1.96 1.96 0.0643 0.0651 1.86 2.07 1.00 1051.
#> # ℹ 1 more variable: ess_tail <dbl>
```

Expand All @@ -211,23 +211,23 @@ var_grad <- stan_variational(loglik_fun, inits, additional_args = list(y),
``` r
summary(var_fd)
#> # A tibble: 5 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 lp__ 0 0 0 0 0 0 NA NA
#> 2 log_p__ -1051. -1050. 1.67 1.33 -1054. -1049. 1.00 996.
#> 3 log_g__ -0.966 -0.697 0.963 0.729 -3.03 -0.0399 1.00 1094.
#> 4 pars[1] 9.92 9.93 0.0796 0.0813 9.79 10.0 0.999 1104.
#> 5 pars[2] 2.06 2.06 0.0696 0.0678 1.95 2.17 1.00 944.
#> variable mean median sd mad q5 q95 rhat ess_bulk
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 lp__ 0 0 0 0 0 0 NA NA
#> 2 log_p__ -1048. -1048. 1.83 1.68 -1051. -1046. 0.999 916.
#> 3 log_g__ -1.01 -0.713 0.994 0.740 -3.06 -0.0434 1.00 968.
#> 4 pars[1] 10.1 10.1 0.0817 0.0857 9.95 10.2 1.00 1064.
#> 5 pars[2] 2.08 2.08 0.0615 0.0624 1.99 2.19 1.00 882.
#> # ℹ 1 more variable: ess_tail <dbl>
summary(var_grad)
#> # A tibble: 5 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 lp__ 0 0 0 0 0 0 NA NA
#> 2 log_p__ -1050. -1050. 1.36 1.02 -1053. -1049. 0.999 1003.
#> 2 log_p__ -1047. -1047. 1.36 1.02 -1050. -1046. 0.999 1001.
#> 3 log_g__ -1.03 -0.714 1.03 0.731 -3.29 -0.0486 1.00 959.
#> 4 pars[1] 10.0 10.0 0.0817 0.0844 9.91 10.2 1.00 1012.
#> 5 pars[2] 1.97 1.96 0.0612 0.0601 1.87 2.07 1.00 850.
#> 4 pars[1] 10.2 10.2 0.0811 0.0838 10.1 10.3 1.00 1012.
#> 5 pars[2] 1.95 1.95 0.0608 0.0597 1.86 2.05 1.00 850.
#> # ℹ 1 more variable: ess_tail <dbl>
```

Expand All @@ -248,16 +248,16 @@ summary(path_fd)
#> # A tibble: 4 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 lp_appr… 3.11e0 3.40e0 0.975 0.716 1.20e0 4.04e0 1.00 1053. 981.
#> 2 lp__ -1.05e3 -1.05e3 0.947 0.682 -1.05e3 -1.05e3 0.999 1044. 956.
#> 3 pars[1] 9.95e0 9.96e0 0.0830 0.0771 9.81e0 1.01e1 1.00 1030. 972.
#> 4 pars[2] 1.98e0 1.98e0 0.0633 0.0671 1.87e0 2.08e0 0.999 760. 809.
#> 1 lp_appr… 3.09e0 3.43e0 1.02 0.714 1.08e0 4.05e0 1.00 953. 912.
#> 2 lp__ -1.05e3 -1.05e3 0.972 0.689 -1.05e3 -1.05e3 0.999 948. 1021.
#> 3 pars[1] 1.01e1 1.01e1 0.0854 0.0801 9.97e0 1.03e1 1.00 1015. 917.
#> 4 pars[2] 1.97e0 1.97e0 0.0614 0.0620 1.87e0 2.07e0 1.00 968. 1025.
summary(path_grad)
#> # A tibble: 4 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 lp_appr… 3.11e0 3.40e0 0.975 0.716 1.20e0 4.04e0 1.00 1053. 981.
#> 2 lp__ -1.05e3 -1.05e3 0.947 0.682 -1.05e3 -1.05e3 0.999 1044. 956.
#> 3 pars[1] 9.95e0 9.96e0 0.0830 0.0771 9.81e0 1.01e1 1.00 1030. 972.
#> 4 pars[2] 1.98e0 1.98e0 0.0633 0.0671 1.87e0 2.08e0 0.999 760. 809.
#> 1 lp_appr… 3.09e0 3.43e0 1.02 0.714 1.08e0 4.05e0 1.00 953. 912.
#> 2 lp__ -1.05e3 -1.05e3 0.972 0.689 -1.05e3 -1.05e3 0.999 948. 1021.
#> 3 pars[1] 1.01e1 1.01e1 0.0854 0.0801 9.97e0 1.03e1 1.00 1015. 917.
#> 4 pars[2] 1.97e0 1.97e0 0.0614 0.0620 1.87e0 2.07e0 1.00 968. 1025.
```
74 changes: 38 additions & 36 deletions inst/include/estimator/estimator_ext_header.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,7 @@ namespace internal {
Rcpp::Function grad_fun("ls");
}

enum boundsType {
SINGLE = 1,
BOTH = 2,
NONE = 3
};

inline double single_b_step(double x, double lb, double hs) {
return std::exp(hs) * (x - lb) + lb;
}

inline double both_b_step(double x, double lb, double ub, double hs) {
return lb + (ub - lb) / (1 + (std::exp(-hs) * x - ub) / (lb - x));
}

inline double fdiff_step(int type, double x, double lb, double ub, double hs) {
switch(type) {
case SINGLE:
return single_b_step(x, lb, hs);
case BOTH:
return both_b_step(x, lb, ub, hs);
case NONE:
return x + hs;
}
return stan::math::NOT_A_NUMBER;
}
enum boundsType { LOWER = 1, UPPER = 2, BOTH = 3, NONE = 4 };

template <typename F, typename T>
Eigen::VectorXd fdiff(const F& f, const T& x,
Expand All @@ -46,16 +22,41 @@ Eigen::VectorXd fdiff(const F& f, const T& x,
return Eigen::VectorXd::NullaryExpr(x.size(), [&f, &x, &x_temp, &cons_type, &lower, &upper](Eigen::Index i) {
double h = stan::math::finite_diff_stepsize(x[i]);
double delta_f = 0;
for (int j = 0; j < 6; ++j) {
x_temp[i] = fdiff_step(cons_type[i], x[i], lower[i], upper[i], h * h_scale[j]);
delta_f += f(x_temp) * mults[j];
double scal = 0;
switch (cons_type[i]) {
case LOWER:
scal = x[i] - lower[i];
for (int j = 0; j < 6; ++j) {
x_temp[i] = lower[i] + std::exp(h * h_scale[j]) * scal;
delta_f += f(x_temp) * mults[j];
}
break;
case UPPER:
scal = x[i] - upper[i];
for (int j = 0; j < 6; ++j) {
x_temp[i] = upper[i] + std::exp(h * h_scale[j]) * scal;
delta_f += f(x_temp) * mults[j];
}
break;
case BOTH:
scal = (x[i] - upper[i]) / (lower[i] - x[i]);
for (int j = 0; j < 6; ++j) {
x_temp[i] = 1 / (1 + std::exp(-h * h_scale[j]) * scal);
delta_f += f(x_temp) * mults[j];
}
break;
case NONE:
for (int j = 0; j < 6; ++j) {
x_temp[i] = x[i] + h * h_scale[j];
delta_f += f(x_temp) * mults[j];
}
break;
}
x_temp[i] = x[i];
return delta_f / (60 * h * (cons_type[i] == 3 ? 1 : x[i]));
return delta_f / (60 * h * (cons_type[i] == NONE ? 1 : x[i]));
});
}


template <typename T, typename TLower, typename TUpper,
stan::require_st_arithmetic<T>* = nullptr>
double r_function(const T& v, int finite_diff, int no_bounds,
Expand All @@ -76,13 +77,14 @@ stan::math::var r_function(const T& v, int finite_diff, int no_bounds,

stan::arena_t<stan::plain_type_t<T>> arena_v = v;
if (finite_diff == 1) {
stan::arena_t<Eigen::VectorXd> arena_grad = fdiff(
[&](const auto& x) { return Rcpp::as<double>(internal::ll_fun(x)); },
v.val(), bounds_types, lower_bounds, upper_bounds);
stan::arena_t<Eigen::VectorXd> arena_grad =
fdiff([&](const auto& x) { return Rcpp::as<double>(internal::ll_fun(x)); },
v.val(), bounds_types, lower_bounds, upper_bounds);
return make_callback_var(
Rcpp::as<double>(internal::ll_fun(v.val())), [arena_v, arena_grad](auto& vi) mutable {
arena_v.adj() += vi.adj() * arena_grad;
});
Rcpp::as<double>(internal::ll_fun(v.val())),
[arena_v, arena_grad](auto& vi) mutable {
arena_v.adj() += vi.adj() * arena_grad;
});
} else {
return make_callback_var(
Rcpp::as<double>(internal::ll_fun(v.val())),
Expand Down

0 comments on commit 5ca656e

Please sign in to comment.