From 5023e0f82066ca34dbd7fb999c2e854a8b32db2f Mon Sep 17 00:00:00 2001
From: teddygroves
Date: Thu, 24 Oct 2024 15:38:06 +0200
Subject: [PATCH] Improve benchmarks

---
 benchmarks/linear_pathway.py     | 177 ++++++++++++++++++++++++++++
 benchmarks/optimistix_example.py | 191 -------------------------------
 benchmarks/simple_example.py     | 132 +++++++++++++++++++++
 src/grapevine/util.py            |   2 +
 4 files changed, 311 insertions(+), 191 deletions(-)
 create mode 100644 benchmarks/linear_pathway.py
 delete mode 100644 benchmarks/optimistix_example.py
 create mode 100644 benchmarks/simple_example.py

diff --git a/benchmarks/linear_pathway.py b/benchmarks/linear_pathway.py
new file mode 100644
index 0000000..a7aacad
--- /dev/null
+++ b/benchmarks/linear_pathway.py
@@ -0,0 +1,177 @@
+"""An example comparing GrapeNUTS and NUTS on a representative problem.
+
+The problem is a steady-state kinetic model of a linear pathway with this
+structure:
+
+    Aext <-r1-> Aint <-r2-> Bint <-r3-> Bext
+
+Reactions r1 and r3 follow the law of mass action, and reaction r2 follows the
+reversible Michaelis-Menten rate law. We assume that we have measurements of
+Aint and Bint (called c_m1_int and c_m2_int in the code below), plenty of
+information about all the kinetic parameters and boundary conditions, and that
+the pathway is at steady state, so that the concentrations c_m1_int and
+c_m2_int are not changing.
+
+To formulate this situation as a statistical modelling problem, the functions
+`rmm` and `ma` specify the two rate laws, and the function `fn` specifies the
+steady state problem, i.e. finding values of c_m1_int and c_m2_int at which
+the system is in a steady state.
+
+We can then specify joint and posterior log density functions in terms of
+log-scale parameters, which we can sample using GrapeNUTS.
+
+The benchmark proceeds by first choosing some true parameter values (see the
+dictionary `TRUE_PARAMS`) and simulating measurements of c_m1_int and c_m2_int
+using these parameters: see the function `simulate` for how this works. The
+log posterior is then sampled with both NUTS and GrapeNUTS, and the run times
+are printed.
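+
+Written out explicitly (this is just a restatement of the `ma`, `rmm` and `fn`
+functions defined below), the three reaction rates are
+
+    v1 = kf_r1 * (c_m1_ext - c_m1_int / keq_r1)              (mass action)
+    v2 = vmax * (c_m1_int - c_m2_int / keq_r2) / km_m1
+         / (1 + c_m1_int / km_m1 + c_m2_int / km_m2)         (Michaelis-Menten)
+    v3 = kf_r3 * (c_m2_int - c_m2_ext / keq_r3)              (mass action)
+
+and steady state requires d(c_m1_int)/dt = v1 - v2 = 0 and
+d(c_m2_int)/dt = v2 - v3 = 0, which is the two-dimensional root-finding
+problem that `fn` encodes.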
+
+"""
+
+from collections import OrderedDict
+from functools import partial
+import timeit
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import optimistix as optx
+
+from blackjax import nuts
+from blackjax import window_adaptation as nuts_window_adaptation
+from blackjax.util import run_inference_algorithm
+from jax.scipy.stats import norm
+
+from grapevine import run_grapenuts
+
+# Use 64 bit floats
+jax.config.update("jax_enable_x64", True)
+
+SEED = 1234
+SD = 0.05
+TRUE_PARAMS = OrderedDict(
+    log_km=jnp.array([2.0, 3.0]),
+    log_vmax=jnp.array(0.0),
+    log_keq=jnp.array([1.0, 1.0, 1.0]),
+    log_kf=jnp.array([1.0, -1.0]),
+    log_conc_ext=jnp.array([1.0, 0.0]),
+)
+DEFAULT_GUESS = jnp.array([0.01, 0.01])
+
+
+@eqx.filter_jit
+def rmm(s, p, km_s, km_p, vmax, k_eq):
+    """Reversible Michaelis-Menten rate law."""
+    num = vmax * (s - p / k_eq) / km_s
+    denom = 1 + s / km_s + p / km_p
+    return num / denom
+
+
+@eqx.filter_jit
+def ma(s, p, kf, keq):
+    """Mass action rate law."""
+    return kf * (s - p / keq)
+
+
+@eqx.filter_jit
+def fn(y, args):
+    """Steady state residuals for the internal concentrations."""
+    S = jnp.array([[-1, 1, 0, 0], [0, -1, 1, 0], [0, 0, -1, 1]]).transpose()
+    c_m1_int, c_m2_int = y
+    km, vmax, keq, kf, conc_ext = map(jnp.exp, args.values())
+    keq_r1, keq_r2, keq_r3 = keq
+    kf_r1, kf_r3 = kf
+    c_m1_ext, c_m2_ext = conc_ext
+    km_m1, km_m2 = km
+    v = jnp.array(
+        [
+            ma(c_m1_ext, c_m1_int, kf_r1, keq_r1),
+            rmm(c_m1_int, c_m2_int, km_m1, km_m2, vmax, keq_r2),
+            ma(c_m2_int, c_m2_ext, kf_r3, keq_r3),
+        ]
+    )
+    return (S @ v)[jnp.array([1, 2])]
+
+
+solver = optx.Newton(rtol=1e-9, atol=1e-9)
+
+
+@eqx.filter_jit
+def joint_logdensity_grapenuts(params, obs, guess):
+    """Joint log density, plus the solver solution for reuse as the next guess."""
+    sol = optx.root_find(fn, solver, guess, args=params)
+    log_km, log_vmax, log_keq, log_kf, log_conc_ext = params.values()
+    log_prior = jnp.sum(
+        norm.logpdf(log_km, loc=TRUE_PARAMS["log_km"], scale=0.1).sum()
+        + norm.logpdf(log_vmax, loc=TRUE_PARAMS["log_vmax"], scale=0.1).sum()
+        + norm.logpdf(log_keq, loc=TRUE_PARAMS["log_keq"], scale=0.1).sum()
+        + norm.logpdf(log_kf, loc=TRUE_PARAMS["log_kf"], scale=0.1).sum()
+        + norm.logpdf(log_conc_ext, loc=TRUE_PARAMS["log_conc_ext"], scale=0.1).sum()
+    )
+    log_likelihood = norm.logpdf(
+        jnp.log(obs), loc=jnp.log(sol.value), scale=jnp.full(obs.shape, SD)
+    ).sum()
+    return log_prior + log_likelihood, sol.value
+
+
+@eqx.filter_jit
+def joint_logdensity_nuts(params, obs):
+    """Joint log density with a fixed guess and no extra output, for plain NUTS."""
+    ld, _ = joint_logdensity_grapenuts(params, obs, DEFAULT_GUESS)
+    return ld
+
+
+@eqx.filter_jit
+def simulate(key, params, guess):
+    """Simulate steady state concentrations and noisy measurements of them."""
+    sol = optx.root_find(fn, solver, guess, args=params)
+    return sol.value, jnp.exp(
+        jnp.log(sol.value) + jax.random.normal(key, shape=sol.value.shape) * SD
+    )
+
+
+def main():
+    key = jax.random.key(SEED)
+    key, sim_key = jax.random.split(key)
+    _, sim = simulate(sim_key, TRUE_PARAMS, DEFAULT_GUESS)
+    posterior_logdensity_gn = partial(joint_logdensity_grapenuts, obs=sim)
+    posterior_logdensity_nuts = partial(joint_logdensity_nuts, obs=sim)
+    key, grapenuts_key = jax.random.split(key)
+    key, nuts_key_warmup = jax.random.split(key)
+    key, nuts_key_sampling = jax.random.split(key)
+
+    def run_grapenuts_example():
+        return run_grapenuts(
+            posterior_logdensity_gn,
+            grapenuts_key,
+            init_parameters=TRUE_PARAMS,
+            default_guess=DEFAULT_GUESS,
+            num_warmup=1000,
+            num_samples=1000,
+            initial_step_size=0.0001,
+            max_num_doublings=10,
+            is_mass_matrix_diagonal=False,
+            target_acceptance_rate=0.95,
+            progress_bar=False,
+        )
+
+    def run_nuts_example():
+        warmup = nuts_window_adaptation(
+            nuts,
+            posterior_logdensity_nuts,
+            progress_bar=False,
+            initial_step_size=0.0001,
+            max_num_doublings=10,
+            is_mass_matrix_diagonal=False,
+            target_acceptance_rate=0.95,
+        )
+        (initial_state, tuned_parameters), _ = warmup.run(
+            nuts_key_warmup,
+            TRUE_PARAMS,
+            num_steps=1000,  # type: ignore
+        )
+        kernel = nuts(posterior_logdensity_nuts, **tuned_parameters)
+        return run_inference_algorithm(
+            nuts_key_sampling,
+            kernel,
+            1000,
+            initial_state,
+        )
+
+    # run once for jitting
+    _ = run_grapenuts_example()
+    _ = run_nuts_example()
+
+    # timers
+    time_grapenuts = timeit.timeit(run_grapenuts_example, number=5)  # type: ignore
+    time_nuts = timeit.timeit(run_nuts_example, number=5)  # type: ignore
+
+    # print results
+    print(f"Runtime for grapenuts: {round(time_grapenuts, 4)}")
+    print(f"Runtime for nuts: {round(time_nuts, 4)}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/benchmarks/optimistix_example.py b/benchmarks/optimistix_example.py
deleted file mode 100644
index ff16466..0000000
--- a/benchmarks/optimistix_example.py
+++ /dev/null
@@ -1,191 +0,0 @@
-"""An example comparing GrapeNUTS and NUTS on a representative problem.
-
-This is supposed to be a complete example, mirroring how the grapevine method
-is used in practice.
-
-
-"""
-
-from functools import partial
-import timeit
-
-from blackjax.util import run_inference_algorithm
-from grapevine.grapenuts import GrapeNUTSState
-import jax
-import jax.numpy as jnp
-import optimistix as optx
-
-from blackjax import nuts
-from blackjax import window_adaptation as nuts_window_adaptation
-from jax.scipy.stats import norm
-
-from grapevine import run_grapenuts
-
-# Use 64 bit floats
-jax.config.update("jax_enable_x64", True)
-
-SEED = 1234
-SD = 0.05
-TRUE_PARAMS = {
-    "log_km_s": 2.0,
-    "log_km_p": 3.0,
-    "log_vmax": -1.0,
-    "log_k_eq": 5.0,
-    "log_s1": 2.0,
-    "log_s2": 2.9,
-    "log_s3": 0.9,
-    "log_s4": 0.1,
-}
-TRUE_PARAMS_ARR = jnp.array(list(TRUE_PARAMS.values()))
-DEFAULT_GUESS = jnp.array([0.01, 0.01, 0.01, 0.01])
-
-# hack the timeit module to not destroy the timed function's return value
-timeit.template = """
-def inner(_it, _timer{init}):
-    {setup}
-    _t0 = _timer()
-    for _i in _it:
-        retval = {stmt}
-    _t1 = _timer()
-    return _t1 - _t0, retval
-"""
-
-
-def rmm(p, km_s, km_p, vmax, k_eq, s):
-    num = vmax * (s - p / k_eq) / km_s
-    denom = 1 + s / km_s + p / km_p
-    return num / denom
-
-
-def fn(y, args):
-    p1, p2, p3, p4 = y
-    km_s, km_p, vmax, k_eq, s1, s2, s3, s4 = args
-    v1 = rmm(p1, km_s, km_p, vmax, k_eq, s1)
-    v2 = rmm(p2, km_s, km_p, vmax, k_eq, s2)
-    v3 = rmm(p3, km_s, km_p, vmax, k_eq, s3)
-    v4 = rmm(p4, km_s, km_p, vmax, k_eq, s4)
-    return jnp.array([v1, v2, v3, v4])
-
-
-solver = optx.Newton(rtol=1e-8, atol=1e-8)
-
-
-def grapenuts_state_from_nuts_state(nuts_state, guess):
-    position, logdensity, logdensity_grad = nuts_state
-    return GrapeNUTSState(position, logdensity, logdensity_grad, guess)
-
-
-def joint_logdensity_grapenuts(params, obs, guess):
-    sol = optx.root_find(fn, solver, guess, args=jnp.exp(params))
-    log_prior = norm.logpdf(
-        params,
-        loc=TRUE_PARAMS_ARR,
-        scale=jnp.full(params.shape, 1),
-    ).sum()
-    log_likelihood = norm.logpdf(
-        jnp.log(obs), loc=jnp.log(sol.value), scale=jnp.full(obs.shape, SD)
-    ).sum()
-    return log_prior + log_likelihood, sol.value
-
-
-def joint_logdensity_nuts(params, obs):
-    sol = optx.root_find(fn, solver, DEFAULT_GUESS, args=jnp.exp(params))
-    log_prior = norm.logpdf(
-        params,
-        loc=TRUE_PARAMS_ARR,
-        scale=jnp.full(params.shape, 1),
-    ).sum()
-    log_likelihood = norm.logpdf(
-        jnp.log(obs), loc=jnp.log(sol.value), scale=jnp.full(obs.shape, SD)
-    ).sum()
-    return log_prior + log_likelihood
-
-
-def simulate(key, params, guess):
-    sol = optx.root_find(fn, solver, guess, args=jnp.exp(params))
-    return sol.value, jnp.exp(
-        jnp.log(sol.value) + jax.random.normal(key, shape=sol.value.shape) * SD
-    )
-
-
-def main():
-    key = jax.random.key(SEED)
-    key, sim_key = jax.random.split(key)
-    true_p, sim = simulate(sim_key, TRUE_PARAMS_ARR, DEFAULT_GUESS)
-    print("True substrate concs: " + str(true_p))
-    print("Simulated measurements: " + str(sim))
-    posterior_logdensity_gn = partial(joint_logdensity_grapenuts, obs=sim)
-    posterior_logdensity_nuts = partial(joint_logdensity_nuts, obs=sim)
-    key, grapenuts_key = jax.random.split(key)
-    key, nuts_key_warmup = jax.random.split(key)
-    key, nuts_key_sampling = jax.random.split(key)
-
-    def run_grapenuts_example():
-        return run_grapenuts(
-            posterior_logdensity_gn,
-            key,
-            init_parameters=(TRUE_PARAMS_ARR),
-            default_guess=DEFAULT_GUESS,
-            num_warmup=1000,
-            num_samples=1000,
-            initial_step_size=0.0001,
-            max_num_doublings=10,
-            is_mass_matrix_diagonal=False,
-            target_acceptance_rate=0.95,
-            progress_bar=False,
-        )
-
-    def run_nuts_example():
-        warmup = nuts_window_adaptation(
-            nuts,
-            posterior_logdensity_nuts,
-            progress_bar=False,
-            initial_step_size=0.0001,
-            max_num_doublings=10,
-            is_mass_matrix_diagonal=False,
-            target_acceptance_rate=0.95,
-        )
-        (initial_state, tuned_parameters), _ = warmup.run(
-            nuts_key_warmup,
-            TRUE_PARAMS_ARR,
-            num_steps=1000,  # type: ignore
-        )
-        kernel = nuts(posterior_logdensity_nuts, **tuned_parameters)
-        return run_inference_algorithm(
-            nuts_key_sampling,
-            kernel,
-            1000,
-            initial_state,
-        )
-
-    time_grapenuts, (state_grapenuts, _) = timeit.timeit(
-        run_grapenuts_example,
-        number=1,
-    )  # type: ignore
-    time_nuts, (_, (state_nuts, _)) = timeit.timeit(
-        run_nuts_example,
-        number=1,
-    )  # type: ignore
-    __import__("pdb").set_trace()
-    print("True param vals: " + str(TRUE_PARAMS_ARR))
-    print("GrapeNUTS quantiles:")
-    print(
-        jnp.quantile(
-            state_grapenuts.position,
-            jnp.array([0.01, 0.5, 0.99]),
-            axis=0,
-        ).round(4)
-    )
-    print("NUTS quantiles:")
-    print(
-        jnp.quantile(
-            state_nuts.position,
-            jnp.array([0.01, 0.5, 0.99]),
-            axis=0,
-        ).round(4)
-    )
-    print(f"Runtime for grapenuts: {round(time_grapenuts, 4)}")
-    print(f"Runtime for nuts: {round(time_nuts, 4)}")
-
-
-if __name__ == "__main__":
-    main()
diff --git a/benchmarks/simple_example.py b/benchmarks/simple_example.py
new file mode 100644
index 0000000..e95b96b
--- /dev/null
+++ b/benchmarks/simple_example.py
@@ -0,0 +1,132 @@
+"""An example comparing GrapeNUTS and NUTS on a simple problem.
+
+The problem is taken from the Stan documentation.
+
+To formulate this situation as a statistical modelling problem, there is a
+function `fn` that takes in a state (`y`) and some parameters (`args`) and
+returns the quantities that should be zero at the solution.
+
+We can then specify joint and posterior log density functions in terms of the
+parameters, which we can sample using GrapeNUTS.
+
+The benchmark proceeds by first choosing some true parameter values (see the
+dictionary `TRUE_PARAMS`) and simulating measurements of the two components of
+the solution using these parameters: see the function `simulate` for how this
+works. The log posterior is then sampled with both NUTS and GrapeNUTS, and the
+run times are printed.
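+
+For reference, the algebraic system solved by `fn` below is
+
+    y1 - theta1 = 0
+    y1 * y2 - theta2 = 0
+
+whose exact solution is y1 = theta1 and y2 = theta2 / theta1, so with the true
+parameters theta = (3.0, 6.0) the simulated measurements are centred near
+y = (3, 2).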
+
+"""
+
+from collections import OrderedDict
+from functools import partial
+import timeit
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import optimistix as optx
+
+from blackjax import nuts
+from blackjax import window_adaptation as nuts_window_adaptation
+from blackjax.util import run_inference_algorithm
+from jax.scipy.stats import norm
+
+from grapevine import run_grapenuts
+
+# Use 64 bit floats
+jax.config.update("jax_enable_x64", True)
+
+SEED = 1234
+SD = 0.05
+TRUE_PARAMS = OrderedDict(theta=jnp.array([3.0, 6.0]))
+DEFAULT_GUESS = jnp.array([1.0, 1.0])
+
+
+@eqx.filter_jit
+def fn(y, args):
+    y1, y2 = y
+    theta1, theta2 = args
+    return jnp.array([y1 - theta1, y1 * y2 - theta2])
+
+
+solver = optx.Newton(rtol=1e-9, atol=1e-9)
+
+
+@eqx.filter_jit
+def joint_logdensity_grapenuts(params, obs, guess):
+    theta = params["theta"]
+    sol = optx.root_find(fn, solver, guess, args=theta)
+    log_prior = jnp.sum(norm.logpdf(theta, loc=TRUE_PARAMS["theta"], scale=0.1).sum())
+    log_likelihood = norm.logpdf(obs, loc=sol.value, scale=SD).sum()
+    return log_prior + log_likelihood, sol.value
+
+
+@eqx.filter_jit
+def joint_logdensity_nuts(params, obs):
+    ld, _ = joint_logdensity_grapenuts(params, obs, DEFAULT_GUESS)
+    return ld
+
+
+@eqx.filter_jit
+def simulate(key, params, guess):
+    theta = params["theta"]
+    sol = optx.root_find(fn, solver, guess, args=theta)
+    return sol.value, jnp.exp(
+        jnp.log(sol.value) + jax.random.normal(key, shape=sol.value.shape) * SD
+    )
+
+
+def main():
+    key = jax.random.key(SEED)
+    key, sim_key = jax.random.split(key)
+    _, sim = simulate(sim_key, TRUE_PARAMS, DEFAULT_GUESS)
+    posterior_logdensity_gn = partial(joint_logdensity_grapenuts, obs=sim)
+    posterior_logdensity_nuts = partial(joint_logdensity_nuts, obs=sim)
+    key, grapenuts_key = jax.random.split(key)
+    key, nuts_key_warmup = jax.random.split(key)
+    key, nuts_key_sampling = jax.random.split(key)
+
+    def run_grapenuts_example():
+        return run_grapenuts(
+            posterior_logdensity_gn,
+            grapenuts_key,
+            init_parameters=TRUE_PARAMS,
+            default_guess=DEFAULT_GUESS,
+            num_warmup=1000,
+            num_samples=1000,
+            initial_step_size=0.0001,
+            max_num_doublings=10,
+            is_mass_matrix_diagonal=False,
+            target_acceptance_rate=0.95,
+            progress_bar=False,
+        )
+
+    def run_nuts_example():
+        warmup = nuts_window_adaptation(
+            nuts,
+            posterior_logdensity_nuts,
+            progress_bar=False,
+            initial_step_size=0.0001,
+            max_num_doublings=10,
+            is_mass_matrix_diagonal=False,
+            target_acceptance_rate=0.95,
+        )
+        (initial_state, tuned_parameters), _ = warmup.run(
+            nuts_key_warmup,
+            TRUE_PARAMS,
+            num_steps=1000,  # type: ignore
+        )
+        kernel = nuts(posterior_logdensity_nuts, **tuned_parameters)
+        return run_inference_algorithm(
+            nuts_key_sampling,
+            kernel,
+            1000,
+            initial_state,
+        )
+
+    # timers
+    _ = run_grapenuts_example()  # run once for jitting
+    time_grapenuts = timeit.timeit(run_grapenuts_example, number=5)  # type: ignore
+    _ = run_nuts_example()  # run once for jitting
+    time_nuts = timeit.timeit(run_nuts_example, number=5)  # type: ignore
+
+    # print results
+    print(f"Runtime for grapenuts: {round(time_grapenuts, 4)}")
+    print(f"Runtime for nuts: {round(time_nuts, 4)}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/grapevine/util.py b/src/grapevine/util.py
index 0e84e52..479119f 100644
--- a/src/grapevine/util.py
+++ b/src/grapevine/util.py
@@ -2,6 +2,7 @@
 
 from typing import Callable, TypedDict, Unpack
 
+import equinox as eqx
 import jax
 
 from blackjax.types import ArrayTree
@@ -21,6 +22,7 @@ class AdaptationKwargs(TypedDict):
     target_acceptance_rate: float
 
 
+@eqx.filter_jit
 def run_grapenuts(
     logdensity_fn: Callable,
     rng_key: KeyArray,