Skip to content

Commit

Permalink
improve benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
teddygroves committed Nov 8, 2024
1 parent a076817 commit 3de08a3
Show file tree
Hide file tree
Showing 15 changed files with 885 additions and 46 deletions.
48 changes: 48 additions & 0 deletions benchmarks/analyse_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from pathlib import Path
from matplotlib import pyplot as plt
import numpy as np
import polars as pl

HERE = Path(__file__).parent
CSV_FILE_SIMPLE = HERE / "simple_example.csv"
CSV_FILE_LINEAR = HERE / "linear_pathway.csv"


def plot_comparison(name: str, df: pl.DataFrame, ax: plt.Axes):
xlow, xhigh = ax.get_xlim()
groups = df.group_by(["algorithm"])
n_group = df["algorithm"].n_unique()
x = np.linspace(xlow, xhigh, n_group)
algs = []
for xi, ((algorithm_name,), subdf) in zip(x, groups):
xs = np.linspace(xi - 0.01, xi + 0.01, len(subdf))
ax.scatter(xs, subdf["neff_per_s"])
algs.append(algorithm_name)
ax.set(
title=name,
ylabel="Effective samples per second\n(higher is better)",
xlabel="Algorithm",
)
ax.set_xticks(x, algs)
ax.set_ylim(ymin=0)
return ax


def comparison_figure(df_dict):
f, axes = plt.subplots(1, len(df_dict.keys()), figsize=[12, 5])
f.suptitle("Benchmark comparison")
for (name, df), ax in zip(df_dict.items(), axes):
ax = plot_comparison(name, df.filter(pl.col("algorithm") != "Stan"), ax)
return f, axes


def main():
df_simple = pl.read_csv(CSV_FILE_SIMPLE)
df_linear = pl.read_csv(CSV_FILE_LINEAR)
df_dict = {"Simple model": df_simple, "Michaelis Menten model": df_linear}
f, _ = comparison_figure(df_dict)
f.savefig(HERE / "figure.png", bbox_inches="tight")


if __name__ == "__main__":
main()
Binary file added benchmarks/figure.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added benchmarks/linear_pathway
Binary file not shown.
31 changes: 31 additions & 0 deletions benchmarks/linear_pathway.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
algorithm,time,neff,neff_per_s
grapeNUTS,0.4110919590020785,412.7281800223967,1003.9801824980719
NUTS,1.58320029200695,418.88989401017136,264.5842703067999
Stan,5.421014916995773,545.005411550966,100.53567826243835
grapeNUTS,0.4117540829902282,412.7281800223967,1002.3657252530307
NUTS,1.4445977910072543,418.88989401017136,289.9699117759954
Stan,5.421461499994621,545.005411550966,100.52739681938287
grapeNUTS,0.41057587500836235,412.7281800223967,1005.2421614250728
NUTS,1.4517032089934219,418.88989401017136,288.5506427313198
Stan,5.45155062500271,545.005411550966,99.97254892055506
grapeNUTS,0.4171639160049381,412.7281800223967,989.3669231389398
NUTS,1.556118750013411,418.88989401017136,269.1888996302893
Stan,5.465081374990405,545.005411550966,99.72503136825165
grapeNUTS,0.4117212919954909,412.7281800223967,1002.4455573381346
NUTS,1.4517152920016088,418.88989401017136,288.5482410484295
Stan,5.425389500000165,545.005411550966,100.45461465042268
grapeNUTS,0.4133557080058381,412.7281800223967,998.4818693166987
NUTS,1.4468039579951437,418.88989401017136,289.52774955815914
Stan,5.518105625000317,545.005411550966,98.76675957085085
grapeNUTS,0.4129722499928903,412.7281800223967,999.4089918378347
NUTS,1.597287917000358,418.88989401017136,262.250712317933
Stan,5.551175000000512,545.005411550966,98.17838773789614
grapeNUTS,0.4166600830067182,412.7281800223967,990.5632837301141
NUTS,1.500209250007174,418.88989401017136,279.2209780123461
Stan,5.495392291006283,545.005411550966,99.17497836194836
grapeNUTS,0.4184131669899216,412.7281800223967,986.4129826304874
NUTS,1.5448935000022175,418.88989401017136,271.14483555634746
Stan,5.51635812499444,545.005411550966,98.79804740768445
grapeNUTS,0.4137904169911053,412.7281800223967,997.43291065938
NUTS,1.4955780839954969,418.88989401017136,280.0856060227161
Stan,5.522869459004141,545.005411550966,98.6815668189338
104 changes: 90 additions & 14 deletions benchmarks/linear_pathway.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,35 @@
"""

import logging
import timeit
from collections import OrderedDict
from functools import partial
import timeit
from pathlib import Path

import arviz as az
import equinox as eqx
import jax
import jax.numpy as jnp
import optimistix as optx

import polars as pl
from blackjax import nuts
from blackjax import window_adaptation as nuts_window_adaptation
from blackjax.util import run_inference_algorithm
from cmdstanpy import CmdStanModel
from jax.scipy.stats import norm

from grapevine import run_grapenuts
from grapevine import run_grapenuts, get_idata

# Use 64 bit floats
jax.config.update("jax_enable_x64", True)


SEED = 1234
SD = 0.05
HERE = Path(__file__).parent
STAN_FILE = HERE / "linear_pathway.stan"
CSV_OUTPUT_FILE = HERE / "linear_pathway.csv"
TRUE_PARAMS = OrderedDict(
log_km=jnp.array([2.0, 3.0]),
log_vmax=jnp.array(0.0),
Expand All @@ -43,6 +51,24 @@
log_conc_ext=jnp.array([1.0, 0.0]),
)
DEFAULT_GUESS = jnp.array([0.01, 0.01])
N_WARMUP = 1000
N_SAMPLE = 1000
INIT_STEPSIZE = 0.0001
MAX_TREEDEPTH = 10
TARGET_ACCEPT = 0.95


# override timeit template:
# see https://stackoverflow.com/questions/24812253/how-can-i-capture-return-value-with-python-timeit-module
timeit.template = """
def inner(_it, _timer{init}):
{setup}
_t0 = _timer()
for _i in _it:
retval = {stmt}
_t1 = _timer()
return _t1 - _t0, retval
"""


@eqx.filter_jit
Expand Down Expand Up @@ -112,7 +138,39 @@ def simulate(key, params, guess):
)


def fit_stan(sim):
data = {
"y": sim,
"sd": SD,
"prior_log_km": [TRUE_PARAMS["log_km"].tolist(), [0.1, 0.1]],
"prior_log_vmax": [TRUE_PARAMS["log_vmax"].tolist(), 0.1],
"prior_log_keq": [TRUE_PARAMS["log_keq"].tolist(), [0.1, 0.1, 0.1]],
"prior_log_kf": [TRUE_PARAMS["log_kf"].tolist(), [0.1, 0.1]],
"prior_log_cext": [TRUE_PARAMS["log_conc_ext"].tolist(), [0.1, 0.1]],
"y_guess": DEFAULT_GUESS.tolist(),
"scaling_step": 1e-9,
"ftol": 1e-9,
"max_steps": 1000,
}
model = CmdStanModel(stan_file=STAN_FILE)
mcmc = model.sample(
data=data,
chains=1,
inits={k: v.tolist() for k, v in TRUE_PARAMS.items()},
iter_warmup=N_WARMUP,
iter_sampling=N_SAMPLE,
adapt_delta=TARGET_ACCEPT,
step_size=INIT_STEPSIZE,
max_treedepth=MAX_TREEDEPTH,
show_progress=False,
seed=SEED,
)
return mcmc


def main():
cmdstanpy_logger = logging.getLogger("cmdstanpy")
cmdstanpy_logger.disabled = True
key = jax.random.key(SEED)
key, sim_key = jax.random.split(key)
_, sim = simulate(sim_key, TRUE_PARAMS, DEFAULT_GUESS)
Expand Down Expand Up @@ -160,17 +218,35 @@ def run_nuts_example():
initial_state,
)

# run once for jitting
_ = run_grapenuts_example()
_ = run_nuts_example()

# timers
time_grapenuts = timeit.timeit(run_grapenuts_example, number=5) #  type: ignore
time_nuts = timeit.timeit(run_nuts_example, number=5) #  type: ignore

# print results
print(f"Runtime for grapenuts: {round(time_grapenuts, 4)}")
print(f"Runtime for nuts: {round(time_nuts, 4)}")
def run_stan_example():
return fit_stan(sim.tolist())

results = []

_ = run_stan_example() # run once for jitting
_ = run_grapenuts_example() # run once for jitting
_ = run_nuts_example() # run once for jitting
for _ in range(10):
_ = run_nuts_example() # one more time to be safe!
time_nuts, (_, out_nuts) = timeit.timeit(run_nuts_example, number=1)
idata_nuts = get_idata(*out_nuts)
neff_nuts = az.ess(idata_nuts.sample_stats)["energy"].item()
time_gn, out_gn = timeit.timeit(run_grapenuts_example, number=1)
idata_gn = get_idata(*out_gn)
neff_gn = az.ess(idata_gn.sample_stats)["energy"].item()
results += [{"algorithm": "grapeNUTS", "time": time_gn, "neff": neff_gn}]
results += [{"algorithm": "NUTS", "time": time_nuts, "neff": neff_nuts}]
time_stan, mcmc = timeit.timeit(run_stan_example, number=1)
idata_stan = az.from_cmdstanpy(mcmc)
neff_stan = az.ess(idata_stan.sample_stats)["lp"].item()
results += [{"algorithm": "Stan", "time": time_stan, "neff": neff_stan}]
results_df = pl.from_records(results).with_columns(
neff_per_s=pl.col("neff") / pl.col("time")
)
print(f"Benchmark results saved to {CSV_OUTPUT_FILE}")
print("Mean results:")
results_df.write_csv(CSV_OUTPUT_FILE)
print(results_df.group_by("algorithm").mean())


if __name__ == "__main__":
Expand Down
65 changes: 65 additions & 0 deletions benchmarks/linear_pathway.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
functions {
real rmm(real s, real p, real km_s, real km_p, real vmax, real keq){
real num = vmax * (s - p / keq) / km_s;
real denom = 1 + s / km_s + p / km_p;
return num / denom;
}
real ma(real s, real p, real kf, real keq){
return kf * (s - p / keq);
}
vector dcdt(vector c, vector km, real vmax, vector keq, vector kf, vector cext){
matrix[4, 3] S = [[-1, 1, 0, 0], [0, -1, 1, 0], [0, 0, -1, 1]]';
vector[3] v = [
ma(cext[1], c[1], kf[1], keq[1]),
rmm(c[1], c[2], km[1], km[2], vmax, keq[2]),
ma(c[1], cext[2], kf[2], keq[3])
]';
vector[2] out = (S * v)[2:3];
return out;
}
}
data {
vector[2] y;
real<lower=0> sd;
array[2] vector[2] prior_log_km;
array[2] real prior_log_vmax;
array[2] vector[3] prior_log_keq;
array[2] vector[2] prior_log_kf;
array[2] vector[2] prior_log_cext;
vector[2] y_guess;
real scaling_step;
real ftol;
int<lower=0> max_steps;
}
parameters {
vector[2] log_km;
real log_vmax;
vector[3] log_keq;
vector[2] log_kf;
vector[2] log_cext;
}
transformed parameters {
vector[2] km = exp(log_km);
real vmax = exp(log_vmax);
vector[3] keq = exp(log_keq);
vector[2] kf = exp(log_kf);
vector[2] cext = exp(log_cext);
vector[2] yhat = solve_newton_tol(dcdt,
y_guess,
scaling_step,
ftol,
max_steps,
km,
vmax,
keq,
kf,
cext);
}
model {
log_km ~ normal(prior_log_km[1], prior_log_km[2]);
log_vmax ~ normal(prior_log_vmax[1], prior_log_vmax[2]);
log_keq ~ normal(prior_log_keq[1], prior_log_keq[2]);
log_kf ~ normal(prior_log_kf[1], prior_log_kf[2]);
log_cext ~ normal(prior_log_cext[1], prior_log_cext[2]);
y ~ lognormal(log(yhat), sd);
}
Binary file added benchmarks/simple_example
Binary file not shown.
31 changes: 31 additions & 0 deletions benchmarks/simple_example.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
algorithm,time,neff,neff_per_s
grapeNUTS,0.12466304200643208,376.1319695228191,3017.18908402229
NUTS,0.9005323750025127,371.5385686486706,412.5765813223916
Stan,0.34212120799929835,458.96174733293986,1341.5179667373359
grapeNUTS,0.12385154099320062,376.1319695228191,3036.958333392626
NUTS,0.9099195839953609,371.5385686486706,408.3202243183776
Stan,0.4125580000109039,476.3393305826259,1154.5996697919716
grapeNUTS,0.12458145900745876,376.1319695228191,3019.1649104085373
NUTS,0.8996170829923358,371.5385686486706,412.9963466376682
Stan,0.41676641699450556,330.3517860118978,792.6545243117633
grapeNUTS,0.12741541699506342,376.1319695228191,2952.013016897256
NUTS,0.9082949589937925,371.5385686486706,409.05056773656474
Stan,0.2685750419914257,490.0342347403413,1824.5710066983281
grapeNUTS,0.12431900000956375,376.1319695228191,3025.538891833779
NUTS,0.8986174169986043,371.5385686486706,413.4557839860427
Stan,0.3757258749974426,328.2939618726023,873.7592583285272
grapeNUTS,0.12386395799694583,376.1319695228191,3036.6538870984045
NUTS,0.9077164170012111,371.5385686486706,409.3112801419949
Stan,0.3016477500059409,397.0605881308124,1316.3054858622097
grapeNUTS,0.12421791600354481,376.1319695228191,3028.000964950059
NUTS,0.9003118749969872,371.5385686486706,412.67762757201655
Stan,0.4210843750042841,362.02231763343286,859.7381881713129
grapeNUTS,0.12391262500023004,376.1319695228191,3035.4612334467197
NUTS,0.9079646249883808,371.5385686486706,409.19938775525003
Stan,0.42453079200640786,412.9657288082927,972.758010924351
grapeNUTS,0.12378912499116268,376.1319695228191,3038.489605202971
NUTS,0.9044461669982411,371.5385686486706,410.7912468486288
Stan,0.4308973750012228,415.0890792691663,963.3130841606736
grapeNUTS,0.12399737500527408,376.1319695228191,3033.3865495686564
NUTS,0.9110034579935018,371.5385686486706,407.83442190986807
Stan,0.3947812080004951,394.9687103098861,1000.4749524688388
Binary file added benchmarks/simple_example.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 3de08a3

Please sign in to comment.