diff --git a/benchmarks/analyse_results.py b/benchmarks/analyse_results.py new file mode 100644 index 0000000..0f99c93 --- /dev/null +++ b/benchmarks/analyse_results.py @@ -0,0 +1,48 @@ +from pathlib import Path +from matplotlib import pyplot as plt +import numpy as np +import polars as pl + +HERE = Path(__file__).parent +CSV_FILE_SIMPLE = HERE / "simple_example.csv" +CSV_FILE_LINEAR = HERE / "linear_pathway.csv" + + +def plot_comparison(name: str, df: pl.DataFrame, ax: plt.Axes): + xlow, xhigh = ax.get_xlim() + groups = df.group_by(["algorithm"]) + n_group = df["algorithm"].n_unique() + x = np.linspace(xlow, xhigh, n_group) + algs = [] + for xi, ((algorithm_name,), subdf) in zip(x, groups): + xs = np.linspace(xi - 0.01, xi + 0.01, len(subdf)) + ax.scatter(xs, subdf["neff_per_s"]) + algs.append(algorithm_name) + ax.set( + title=name, + ylabel="Effective samples per second\n(higher is better)", + xlabel="Algorithm", + ) + ax.set_xticks(x, algs) + ax.set_ylim(ymin=0) + return ax + + +def comparison_figure(df_dict): + f, axes = plt.subplots(1, len(df_dict.keys()), figsize=[12, 5]) + f.suptitle("Benchmark comparison") + for (name, df), ax in zip(df_dict.items(), axes): + ax = plot_comparison(name, df.filter(pl.col("algorithm") != "Stan"), ax) + return f, axes + + +def main(): + df_simple = pl.read_csv(CSV_FILE_SIMPLE) + df_linear = pl.read_csv(CSV_FILE_LINEAR) + df_dict = {"Simple model": df_simple, "Michaelis Menten model": df_linear} + f, _ = comparison_figure(df_dict) + f.savefig(HERE / "figure.png", bbox_inches="tight") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/figure.png b/benchmarks/figure.png new file mode 100644 index 0000000..93d35b7 Binary files /dev/null and b/benchmarks/figure.png differ diff --git a/benchmarks/linear_pathway b/benchmarks/linear_pathway new file mode 100755 index 0000000..cabd3f1 Binary files /dev/null and b/benchmarks/linear_pathway differ diff --git a/benchmarks/linear_pathway.csv b/benchmarks/linear_pathway.csv new file mode 100644 index 0000000..d4377b4 --- /dev/null +++ b/benchmarks/linear_pathway.csv @@ -0,0 +1,31 @@ +algorithm,time,neff,neff_per_s +grapeNUTS,0.4110919590020785,412.7281800223967,1003.9801824980719 +NUTS,1.58320029200695,418.88989401017136,264.5842703067999 +Stan,5.421014916995773,545.005411550966,100.53567826243835 +grapeNUTS,0.4117540829902282,412.7281800223967,1002.3657252530307 +NUTS,1.4445977910072543,418.88989401017136,289.9699117759954 +Stan,5.421461499994621,545.005411550966,100.52739681938287 +grapeNUTS,0.41057587500836235,412.7281800223967,1005.2421614250728 +NUTS,1.4517032089934219,418.88989401017136,288.5506427313198 +Stan,5.45155062500271,545.005411550966,99.97254892055506 +grapeNUTS,0.4171639160049381,412.7281800223967,989.3669231389398 +NUTS,1.556118750013411,418.88989401017136,269.1888996302893 +Stan,5.465081374990405,545.005411550966,99.72503136825165 +grapeNUTS,0.4117212919954909,412.7281800223967,1002.4455573381346 +NUTS,1.4517152920016088,418.88989401017136,288.5482410484295 +Stan,5.425389500000165,545.005411550966,100.45461465042268 +grapeNUTS,0.4133557080058381,412.7281800223967,998.4818693166987 +NUTS,1.4468039579951437,418.88989401017136,289.52774955815914 +Stan,5.518105625000317,545.005411550966,98.76675957085085 +grapeNUTS,0.4129722499928903,412.7281800223967,999.4089918378347 +NUTS,1.597287917000358,418.88989401017136,262.250712317933 +Stan,5.551175000000512,545.005411550966,98.17838773789614 +grapeNUTS,0.4166600830067182,412.7281800223967,990.5632837301141 +NUTS,1.500209250007174,418.88989401017136,279.2209780123461 +Stan,5.495392291006283,545.005411550966,99.17497836194836 +grapeNUTS,0.4184131669899216,412.7281800223967,986.4129826304874 +NUTS,1.5448935000022175,418.88989401017136,271.14483555634746 +Stan,5.51635812499444,545.005411550966,98.79804740768445 +grapeNUTS,0.4137904169911053,412.7281800223967,997.43291065938 +NUTS,1.4955780839954969,418.88989401017136,280.0856060227161 +Stan,5.522869459004141,545.005411550966,98.6815668189338 diff --git a/benchmarks/linear_pathway.py b/benchmarks/linear_pathway.py index c9a014c..a01edc5 100644 --- a/benchmarks/linear_pathway.py +++ b/benchmarks/linear_pathway.py @@ -14,27 +14,35 @@ """ +import logging +import timeit from collections import OrderedDict from functools import partial -import timeit +from pathlib import Path +import arviz as az import equinox as eqx import jax import jax.numpy as jnp import optimistix as optx - +import polars as pl from blackjax import nuts from blackjax import window_adaptation as nuts_window_adaptation from blackjax.util import run_inference_algorithm +from cmdstanpy import CmdStanModel from jax.scipy.stats import norm -from grapevine import run_grapenuts +from grapevine import run_grapenuts, get_idata # Use 64 bit floats jax.config.update("jax_enable_x64", True) + SEED = 1234 SD = 0.05 +HERE = Path(__file__).parent +STAN_FILE = HERE / "linear_pathway.stan" +CSV_OUTPUT_FILE = HERE / "linear_pathway.csv" TRUE_PARAMS = OrderedDict( log_km=jnp.array([2.0, 3.0]), log_vmax=jnp.array(0.0), @@ -43,6 +51,24 @@ log_conc_ext=jnp.array([1.0, 0.0]), ) DEFAULT_GUESS = jnp.array([0.01, 0.01]) +N_WARMUP = 1000 +N_SAMPLE = 1000 +INIT_STEPSIZE = 0.0001 +MAX_TREEDEPTH = 10 +TARGET_ACCEPT = 0.95 + + +# override timeit template: +# see https://stackoverflow.com/questions/24812253/how-can-i-capture-return-value-with-python-timeit-module +timeit.template = """ +def inner(_it, _timer{init}): + {setup} + _t0 = _timer() + for _i in _it: + retval = {stmt} + _t1 = _timer() + return _t1 - _t0, retval +""" @eqx.filter_jit @@ -112,7 +138,39 @@ def simulate(key, params, guess): ) +def fit_stan(sim): + data = { + "y": sim, + "sd": SD, + "prior_log_km": [TRUE_PARAMS["log_km"].tolist(), [0.1, 0.1]], + "prior_log_vmax": [TRUE_PARAMS["log_vmax"].tolist(), 0.1], + "prior_log_keq": [TRUE_PARAMS["log_keq"].tolist(), [0.1, 0.1, 0.1]], + "prior_log_kf": [TRUE_PARAMS["log_kf"].tolist(), [0.1, 0.1]], + "prior_log_cext": [TRUE_PARAMS["log_conc_ext"].tolist(), [0.1, 0.1]], + "y_guess": DEFAULT_GUESS.tolist(), + "scaling_step": 1e-9, + "ftol": 1e-9, + "max_steps": 1000, + } + model = CmdStanModel(stan_file=STAN_FILE) + mcmc = model.sample( + data=data, + chains=1, + inits={k: v.tolist() for k, v in TRUE_PARAMS.items()}, + iter_warmup=N_WARMUP, + iter_sampling=N_SAMPLE, + adapt_delta=TARGET_ACCEPT, + step_size=INIT_STEPSIZE, + max_treedepth=MAX_TREEDEPTH, + show_progress=False, + seed=SEED, + ) + return mcmc + + def main(): + cmdstanpy_logger = logging.getLogger("cmdstanpy") + cmdstanpy_logger.disabled = True key = jax.random.key(SEED) key, sim_key = jax.random.split(key) _, sim = simulate(sim_key, TRUE_PARAMS, DEFAULT_GUESS) @@ -160,17 +218,35 @@ def run_nuts_example(): initial_state, ) - # run once for jitting - _ = run_grapenuts_example() - _ = run_nuts_example() - - # timers - time_grapenuts = timeit.timeit(run_grapenuts_example, number=5) #  type: ignore - time_nuts = timeit.timeit(run_nuts_example, number=5) #  type: ignore - - # print results - print(f"Runtime for grapenuts: {round(time_grapenuts, 4)}") - print(f"Runtime for nuts: {round(time_nuts, 4)}") + def run_stan_example(): + return fit_stan(sim.tolist()) + + results = [] + + _ = run_stan_example() # run once for jitting + _ = run_grapenuts_example() # run once for jitting + _ = run_nuts_example() # run once for jitting + for _ in range(10): + _ = run_nuts_example() # one more time to be safe! + time_nuts, (_, out_nuts) = timeit.timeit(run_nuts_example, number=1) + idata_nuts = get_idata(*out_nuts) + neff_nuts = az.ess(idata_nuts.sample_stats)["energy"].item() + time_gn, out_gn = timeit.timeit(run_grapenuts_example, number=1) + idata_gn = get_idata(*out_gn) + neff_gn = az.ess(idata_gn.sample_stats)["energy"].item() + results += [{"algorithm": "grapeNUTS", "time": time_gn, "neff": neff_gn}] + results += [{"algorithm": "NUTS", "time": time_nuts, "neff": neff_nuts}] + time_stan, mcmc = timeit.timeit(run_stan_example, number=1) + idata_stan = az.from_cmdstanpy(mcmc) + neff_stan = az.ess(idata_stan.sample_stats)["lp"].item() + results += [{"algorithm": "Stan", "time": time_stan, "neff": neff_stan}] + results_df = pl.from_records(results).with_columns( + neff_per_s=pl.col("neff") / pl.col("time") + ) + print(f"Benchmark results saved to {CSV_OUTPUT_FILE}") + print("Mean results:") + results_df.write_csv(CSV_OUTPUT_FILE) + print(results_df.group_by("algorithm").mean()) if __name__ == "__main__": diff --git a/benchmarks/linear_pathway.stan b/benchmarks/linear_pathway.stan new file mode 100644 index 0000000..b38a513 --- /dev/null +++ b/benchmarks/linear_pathway.stan @@ -0,0 +1,65 @@ +functions { + real rmm(real s, real p, real km_s, real km_p, real vmax, real keq){ + real num = vmax * (s - p / keq) / km_s; + real denom = 1 + s / km_s + p / km_p; + return num / denom; + } + real ma(real s, real p, real kf, real keq){ + return kf * (s - p / keq); + } + vector dcdt(vector c, vector km, real vmax, vector keq, vector kf, vector cext){ + matrix[4, 3] S = [[-1, 1, 0, 0], [0, -1, 1, 0], [0, 0, -1, 1]]'; + vector[3] v = [ + ma(cext[1], c[1], kf[1], keq[1]), + rmm(c[1], c[2], km[1], km[2], vmax, keq[2]), + ma(c[1], cext[2], kf[2], keq[3]) + ]'; + vector[2] out = (S * v)[2:3]; + return out; + } +} +data { + vector[2] y; + real sd; + array[2] vector[2] prior_log_km; + array[2] real prior_log_vmax; + array[2] vector[3] prior_log_keq; + array[2] vector[2] prior_log_kf; + array[2] vector[2] prior_log_cext; + vector[2] y_guess; + real scaling_step; + real ftol; + int max_steps; +} +parameters { + vector[2] log_km; + real log_vmax; + vector[3] log_keq; + vector[2] log_kf; + vector[2] log_cext; +} +transformed parameters { + vector[2] km = exp(log_km); + real vmax = exp(log_vmax); + vector[3] keq = exp(log_keq); + vector[2] kf = exp(log_kf); + vector[2] cext = exp(log_cext); + vector[2] yhat = solve_newton_tol(dcdt, + y_guess, + scaling_step, + ftol, + max_steps, + km, + vmax, + keq, + kf, + cext); +} +model { + log_km ~ normal(prior_log_km[1], prior_log_km[2]); + log_vmax ~ normal(prior_log_vmax[1], prior_log_vmax[2]); + log_keq ~ normal(prior_log_keq[1], prior_log_keq[2]); + log_kf ~ normal(prior_log_kf[1], prior_log_kf[2]); + log_cext ~ normal(prior_log_cext[1], prior_log_cext[2]); + y ~ lognormal(log(yhat), sd); +} diff --git a/benchmarks/simple_example b/benchmarks/simple_example new file mode 100755 index 0000000..d814b18 Binary files /dev/null and b/benchmarks/simple_example differ diff --git a/benchmarks/simple_example.csv b/benchmarks/simple_example.csv new file mode 100644 index 0000000..74c21db --- /dev/null +++ b/benchmarks/simple_example.csv @@ -0,0 +1,31 @@ +algorithm,time,neff,neff_per_s +grapeNUTS,0.12466304200643208,376.1319695228191,3017.18908402229 +NUTS,0.9005323750025127,371.5385686486706,412.5765813223916 +Stan,0.34212120799929835,458.96174733293986,1341.5179667373359 +grapeNUTS,0.12385154099320062,376.1319695228191,3036.958333392626 +NUTS,0.9099195839953609,371.5385686486706,408.3202243183776 +Stan,0.4125580000109039,476.3393305826259,1154.5996697919716 +grapeNUTS,0.12458145900745876,376.1319695228191,3019.1649104085373 +NUTS,0.8996170829923358,371.5385686486706,412.9963466376682 +Stan,0.41676641699450556,330.3517860118978,792.6545243117633 +grapeNUTS,0.12741541699506342,376.1319695228191,2952.013016897256 +NUTS,0.9082949589937925,371.5385686486706,409.05056773656474 +Stan,0.2685750419914257,490.0342347403413,1824.5710066983281 +grapeNUTS,0.12431900000956375,376.1319695228191,3025.538891833779 +NUTS,0.8986174169986043,371.5385686486706,413.4557839860427 +Stan,0.3757258749974426,328.2939618726023,873.7592583285272 +grapeNUTS,0.12386395799694583,376.1319695228191,3036.6538870984045 +NUTS,0.9077164170012111,371.5385686486706,409.3112801419949 +Stan,0.3016477500059409,397.0605881308124,1316.3054858622097 +grapeNUTS,0.12421791600354481,376.1319695228191,3028.000964950059 +NUTS,0.9003118749969872,371.5385686486706,412.67762757201655 +Stan,0.4210843750042841,362.02231763343286,859.7381881713129 +grapeNUTS,0.12391262500023004,376.1319695228191,3035.4612334467197 +NUTS,0.9079646249883808,371.5385686486706,409.19938775525003 +Stan,0.42453079200640786,412.9657288082927,972.758010924351 +grapeNUTS,0.12378912499116268,376.1319695228191,3038.489605202971 +NUTS,0.9044461669982411,371.5385686486706,410.7912468486288 +Stan,0.4308973750012228,415.0890792691663,963.3130841606736 +grapeNUTS,0.12399737500527408,376.1319695228191,3033.3865495686564 +NUTS,0.9110034579935018,371.5385686486706,407.83442190986807 +Stan,0.3947812080004951,394.9687103098861,1000.4749524688388 diff --git a/benchmarks/simple_example.png b/benchmarks/simple_example.png new file mode 100644 index 0000000..7fc1e36 Binary files /dev/null and b/benchmarks/simple_example.png differ diff --git a/benchmarks/simple_example.py b/benchmarks/simple_example.py index e95b96b..8c092ad 100644 --- a/benchmarks/simple_example.py +++ b/benchmarks/simple_example.py @@ -10,29 +10,53 @@ """ +import timeit from collections import OrderedDict from functools import partial -import timeit +import logging +from pathlib import Path +import arviz as az import equinox as eqx import jax import jax.numpy as jnp import optimistix as optx - +import polars as pl from blackjax import nuts from blackjax import window_adaptation as nuts_window_adaptation from blackjax.util import run_inference_algorithm +from cmdstanpy import CmdStanModel from jax.scipy.stats import norm -from grapevine import run_grapenuts +from grapevine import run_grapenuts, get_idata # Use 64 bit floats jax.config.update("jax_enable_x64", True) +# override timeit template: +# see https://stackoverflow.com/questions/24812253/how-can-i-capture-return-value-with-python-timeit-module +timeit.template = """ +def inner(_it, _timer{init}): + {setup} + _t0 = _timer() + for _i in _it: + retval = {stmt} + _t1 = _timer() + return _t1 - _t0, retval +""" + SEED = 1234 SD = 0.05 +HERE = Path(__file__).parent +STAN_FILE = HERE / "simple_example.stan" +CSV_OUTPUT_FILE = HERE / "simple_example.csv" TRUE_PARAMS = OrderedDict(theta=jnp.array([3.0, 6.0])) DEFAULT_GUESS = jnp.array([1.0, 1.0]) +N_WARMUP = 1000 +N_SAMPLE = 1000 +INIT_STEPSIZE = 0.0001 +MAX_TREEDEPTH = 10 +TARGET_ACCEPT = 0.95 @eqx.filter_jit @@ -69,15 +93,47 @@ def simulate(key, params, guess): ) +def fit_stan(sim): + data = { + "prior_mean_theta": TRUE_PARAMS["theta"].tolist(), + "prior_sd_theta": [0.1, 0.1], + "y": sim, + "sd": SD, + "y_guess": DEFAULT_GUESS.tolist(), + "scaling_step": 1e-3, + "ftol": 1e-9, + "max_steps": 256, + } + model = CmdStanModel(stan_file=STAN_FILE) + mcmc = model.sample( + data=data, + chains=1, + inits={k: v.tolist() for k, v in TRUE_PARAMS.items()}, + iter_warmup=N_WARMUP, + iter_sampling=N_SAMPLE, + adapt_delta=TARGET_ACCEPT, + step_size=INIT_STEPSIZE, + max_treedepth=MAX_TREEDEPTH, + show_progress=False, + ) + return mcmc + + def main(): + # disable cmdstanpy logging + cmdstanpy_logger = logging.getLogger("cmdstanpy") + cmdstanpy_logger.disabled = True + # keys key = jax.random.key(SEED) key, sim_key = jax.random.split(key) - _, sim = simulate(sim_key, TRUE_PARAMS, DEFAULT_GUESS) - posterior_logdensity_gn = partial(joint_logdensity_grapenuts, obs=sim) - posterior_logdensity_nuts = partial(joint_logdensity_nuts, obs=sim) key, grapenuts_key = jax.random.split(key) key, nuts_key_warmup = jax.random.split(key) key, nuts_key_sampling = jax.random.split(key) + # simulate some data + _, sim = simulate(sim_key, TRUE_PARAMS, DEFAULT_GUESS) + # specify posteriors + posterior_logdensity_gn = partial(joint_logdensity_grapenuts, obs=sim) + posterior_logdensity_nuts = partial(joint_logdensity_nuts, obs=sim) def run_grapenuts_example(): return run_grapenuts( @@ -85,12 +141,12 @@ def run_grapenuts_example(): grapenuts_key, init_parameters=TRUE_PARAMS, default_guess=DEFAULT_GUESS, - num_warmup=1000, - num_samples=1000, - initial_step_size=0.0001, - max_num_doublings=10, + num_warmup=N_WARMUP, + num_samples=N_SAMPLE, + initial_step_size=INIT_STEPSIZE, + max_num_doublings=MAX_TREEDEPTH, is_mass_matrix_diagonal=False, - target_acceptance_rate=0.95, + target_acceptance_rate=TARGET_ACCEPT, progress_bar=False, ) @@ -99,33 +155,53 @@ def run_nuts_example(): nuts, posterior_logdensity_nuts, progress_bar=False, - initial_step_size=0.0001, - max_num_doublings=10, + initial_step_size=INIT_STEPSIZE, + max_num_doublings=MAX_TREEDEPTH, is_mass_matrix_diagonal=False, - target_acceptance_rate=0.95, + target_acceptance_rate=TARGET_ACCEPT, ) (initial_state, tuned_parameters), _ = warmup.run( nuts_key_warmup, TRUE_PARAMS, - num_steps=1000, #  type: ignore + num_steps=N_WARMUP, #  type: ignore ) kernel = nuts(posterior_logdensity_nuts, **tuned_parameters) return run_inference_algorithm( nuts_key_sampling, kernel, - 1000, + N_SAMPLE, initial_state, ) - # timers + def run_stan_example(): + return fit_stan(sim.tolist()) + + results = [] + + _ = run_stan_example() # run once for jitting _ = run_grapenuts_example() # run once for jitting - time_grapenuts = timeit.timeit(run_grapenuts_example, number=5) #  type: ignore _ = run_nuts_example() # run once for jitting - time_nuts = timeit.timeit(run_nuts_example, number=5) #  type: ignore - - # print results - print(f"Runtime for grapenuts: {round(time_grapenuts, 4)}") - print(f"Runtime for nuts: {round(time_nuts, 4)}") + for _ in range(10): + _ = run_grapenuts_example() # one more time to be safe! + time_gn, out_gn = timeit.timeit(run_grapenuts_example, number=1) + idata_gn = get_idata(*out_gn) + neff_gn = az.ess(idata_gn.sample_stats)["energy"].item() + results += [{"algorithm": "grapeNUTS", "time": time_gn, "neff": neff_gn}] + time_nuts, (_, out_nuts) = timeit.timeit(run_nuts_example, number=1) + idata_nuts = get_idata(*out_nuts) + neff_nuts = az.ess(idata_nuts.sample_stats)["energy"].item() + results += [{"algorithm": "NUTS", "time": time_nuts, "neff": neff_nuts}] + time_stan, mcmc = timeit.timeit(run_stan_example, number=1) + idata_stan = az.from_cmdstanpy(mcmc) + neff_stan = az.ess(idata_stan.sample_stats)["lp"].item() + results += [{"algorithm": "Stan", "time": time_stan, "neff": neff_stan}] + results_df = pl.from_records(results).with_columns( + neff_per_s=pl.col("neff") / pl.col("time") + ) + print(f"Benchmark results saved to {CSV_OUTPUT_FILE}") + print("Mean results:") + results_df.write_csv(CSV_OUTPUT_FILE) + print(results_df.group_by("algorithm").mean()) if __name__ == "__main__": diff --git a/benchmarks/simple_example.stan b/benchmarks/simple_example.stan new file mode 100644 index 0000000..f64ee5e --- /dev/null +++ b/benchmarks/simple_example.stan @@ -0,0 +1,33 @@ +functions { + vector system(vector y, vector theta){ + vector[2] z; + z[1] = y[1] - theta[1]; + z[2] = y[1] * y[2] - theta[2]; + return z; + } +} +data { + vector[2] prior_mean_theta; + vector[2] prior_sd_theta; + vector[2] y; + real sd; + vector[2] y_guess; + real scaling_step; + real ftol; + int max_steps; +} +parameters { + vector[2] theta; +} +transformed parameters { + vector[2] yhat = solve_newton_tol(system, + y_guess, + scaling_step, + ftol, + max_steps, + theta); +} +model { + theta ~ normal([3, 6]', 1); + y ~ normal(yhat, sd); +} diff --git a/pdm.lock b/pdm.lock index 9042e4c..428ec3c 100644 --- a/pdm.lock +++ b/pdm.lock @@ -2,10 +2,10 @@ # It is not intended for manual editing. [metadata] -groups = ["default", "dev"] +groups = ["default", "benchmarks", "dev"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:a150c284cfae7bdedd9707da184b565598f659bc084a0613bc01e94f7a886db6" +content_hash = "sha256:f233b60e1eb71aee9ea1e89d1a0efa703ca77d8f8b4876ef825905c7a3ce7cb9" [[metadata.targets]] requires_python = ">=3.12" @@ -21,6 +21,29 @@ files = [ {file = "absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308"}, ] +[[package]] +name = "arviz" +version = "0.20.0" +requires_python = ">=3.10" +summary = "Exploratory analysis of Bayesian models" +groups = ["default"] +dependencies = [ + "h5netcdf>=1.0.2", + "matplotlib>=3.5", + "numpy>=1.23.0", + "packaging", + "pandas>=1.5.0", + "scipy>=1.9.0", + "setuptools>=60.0.0", + "typing-extensions>=4.1.0", + "xarray-einstats>=0.3", + "xarray>=2022.6.0", +] +files = [ + {file = "arviz-0.20.0-py3-none-any.whl", hash = "sha256:5ec4f2ec180a8305ff3d1108c29e189944ab939663eb5bc3231ff199a1a5dc36"}, + {file = "arviz-0.20.0.tar.gz", hash = "sha256:a2704e0c141410fcaea1973a90cabf280f5aed5c1e10f44381ebd6c144c10a9c"}, +] + [[package]] name = "blackjax" version = "1.2.4" @@ -60,18 +83,76 @@ files = [ {file = "chex-0.1.87.tar.gz", hash = "sha256:0096d89cc8d898bb521ef4bfbf5c24549022b0e5b301f529ab57238896fe6c5d"}, ] +[[package]] +name = "cmdstanpy" +version = "1.2.4" +requires_python = ">=3.8" +summary = "Python interface to CmdStan" +groups = ["benchmarks"] +dependencies = [ + "numpy>=1.21", + "pandas", + "stanio<2.0.0,>=0.4.0", + "tqdm", +] +files = [ + {file = "cmdstanpy-1.2.4-py3-none-any.whl", hash = "sha256:ad60f8ca17050216ab7140e13aa493628d88af8a689f17253a5ad294a9826c78"}, + {file = "cmdstanpy-1.2.4.tar.gz", hash = "sha256:ad586be0b9f4c654ecbdc4af4541f4d282f99175956cda88cc5eb873719356cc"}, +] + [[package]] name = "colorama" version = "0.4.6" requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" summary = "Cross-platform colored terminal text." -groups = ["dev"] -marker = "sys_platform == \"win32\"" +groups = ["benchmarks", "dev"] +marker = "sys_platform == \"win32\" or platform_system == \"Windows\"" files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "contourpy" +version = "1.3.0" +requires_python = ">=3.9" +summary = "Python library for calculating contours of 2D quadrilateral grids" +groups = ["default", "benchmarks"] +dependencies = [ + "numpy>=1.23", +] +files = [ + {file = "contourpy-1.3.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:570ef7cf892f0afbe5b2ee410c507ce12e15a5fa91017a0009f79f7d93a1268f"}, + {file = "contourpy-1.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:da84c537cb8b97d153e9fb208c221c45605f73147bd4cadd23bdae915042aad6"}, + {file = "contourpy-1.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0be4d8425bfa755e0fd76ee1e019636ccc7c29f77a7c86b4328a9eb6a26d0639"}, + {file = "contourpy-1.3.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9c0da700bf58f6e0b65312d0a5e695179a71d0163957fa381bb3c1f72972537c"}, + {file = "contourpy-1.3.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eb8b141bb00fa977d9122636b16aa67d37fd40a3d8b52dd837e536d64b9a4d06"}, + {file = "contourpy-1.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3634b5385c6716c258d0419c46d05c8aa7dc8cb70326c9a4fb66b69ad2b52e09"}, + {file = "contourpy-1.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0dce35502151b6bd35027ac39ba6e5a44be13a68f55735c3612c568cac3805fd"}, + {file = "contourpy-1.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:aea348f053c645100612b333adc5983d87be69acdc6d77d3169c090d3b01dc35"}, + {file = "contourpy-1.3.0-cp312-cp312-win32.whl", hash = "sha256:90f73a5116ad1ba7174341ef3ea5c3150ddf20b024b98fb0c3b29034752c8aeb"}, + {file = "contourpy-1.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:b11b39aea6be6764f84360fce6c82211a9db32a7c7de8fa6dd5397cf1d079c3b"}, + {file = "contourpy-1.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:3e1c7fa44aaae40a2247e2e8e0627f4bea3dd257014764aa644f319a5f8600e3"}, + {file = "contourpy-1.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:364174c2a76057feef647c802652f00953b575723062560498dc7930fc9b1cb7"}, + {file = "contourpy-1.3.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32b238b3b3b649e09ce9aaf51f0c261d38644bdfa35cbaf7b263457850957a84"}, + {file = "contourpy-1.3.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d51fca85f9f7ad0b65b4b9fe800406d0d77017d7270d31ec3fb1cc07358fdea0"}, + {file = "contourpy-1.3.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:732896af21716b29ab3e988d4ce14bc5133733b85956316fb0c56355f398099b"}, + {file = "contourpy-1.3.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d73f659398a0904e125280836ae6f88ba9b178b2fed6884f3b1f95b989d2c8da"}, + {file = "contourpy-1.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c6c7c2408b7048082932cf4e641fa3b8ca848259212f51c8c59c45aa7ac18f14"}, + {file = "contourpy-1.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f317576606de89da6b7e0861cf6061f6146ead3528acabff9236458a6ba467f8"}, + {file = "contourpy-1.3.0-cp313-cp313-win32.whl", hash = "sha256:31cd3a85dbdf1fc002280c65caa7e2b5f65e4a973fcdf70dd2fdcb9868069294"}, + {file = "contourpy-1.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:4553c421929ec95fb07b3aaca0fae668b2eb5a5203d1217ca7c34c063c53d087"}, + {file = "contourpy-1.3.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:345af746d7766821d05d72cb8f3845dfd08dd137101a2cb9b24de277d716def8"}, + {file = "contourpy-1.3.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:3bb3808858a9dc68f6f03d319acd5f1b8a337e6cdda197f02f4b8ff67ad2057b"}, + {file = "contourpy-1.3.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:420d39daa61aab1221567b42eecb01112908b2cab7f1b4106a52caaec8d36973"}, + {file = "contourpy-1.3.0-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4d63ee447261e963af02642ffcb864e5a2ee4cbfd78080657a9880b8b1868e18"}, + {file = "contourpy-1.3.0-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:167d6c890815e1dac9536dca00828b445d5d0df4d6a8c6adb4a7ec3166812fa8"}, + {file = "contourpy-1.3.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:710a26b3dc80c0e4febf04555de66f5fd17e9cf7170a7b08000601a10570bda6"}, + {file = "contourpy-1.3.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:75ee7cb1a14c617f34a51d11fa7524173e56551646828353c4af859c56b766e2"}, + {file = "contourpy-1.3.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:33c92cdae89ec5135d036e7218e69b0bb2851206077251f04a6c4e0e21f03927"}, + {file = "contourpy-1.3.0.tar.gz", hash = "sha256:7ffa0db17717a8ffb127efd0c95a4362d996b892c2904db72428d5b52e1938a4"}, +] + [[package]] name = "coverage" version = "7.6.4" @@ -157,6 +238,17 @@ files = [ {file = "coverage-7.6.4.tar.gz", hash = "sha256:29fc0f17b1d3fea332f8001d4558f8214af7f1d87a345f3a133c901d60347c73"}, ] +[[package]] +name = "cycler" +version = "0.12.1" +requires_python = ">=3.8" +summary = "Composable style cycles" +groups = ["default", "benchmarks"] +files = [ + {file = "cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30"}, + {file = "cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c"}, +] + [[package]] name = "equinox" version = "0.11.8" @@ -211,6 +303,69 @@ files = [ {file = "fastprogress-1.0.3.tar.gz", hash = "sha256:7a17d2b438890f838c048eefce32c4ded47197ecc8ea042cecc33d3deb8022f5"}, ] +[[package]] +name = "fonttools" +version = "4.54.1" +requires_python = ">=3.8" +summary = "Tools to manipulate font files" +groups = ["default", "benchmarks"] +files = [ + {file = "fonttools-4.54.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:54471032f7cb5fca694b5f1a0aaeba4af6e10ae989df408e0216f7fd6cdc405d"}, + {file = "fonttools-4.54.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8fa92cb248e573daab8d032919623cc309c005086d743afb014c836636166f08"}, + {file = "fonttools-4.54.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a911591200114969befa7f2cb74ac148bce5a91df5645443371aba6d222e263"}, + {file = "fonttools-4.54.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93d458c8a6a354dc8b48fc78d66d2a8a90b941f7fec30e94c7ad9982b1fa6bab"}, + {file = "fonttools-4.54.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5eb2474a7c5be8a5331146758debb2669bf5635c021aee00fd7c353558fc659d"}, + {file = "fonttools-4.54.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c9c563351ddc230725c4bdf7d9e1e92cbe6ae8553942bd1fb2b2ff0884e8b714"}, + {file = "fonttools-4.54.1-cp312-cp312-win32.whl", hash = "sha256:fdb062893fd6d47b527d39346e0c5578b7957dcea6d6a3b6794569370013d9ac"}, + {file = "fonttools-4.54.1-cp312-cp312-win_amd64.whl", hash = "sha256:e4564cf40cebcb53f3dc825e85910bf54835e8a8b6880d59e5159f0f325e637e"}, + {file = "fonttools-4.54.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:6e37561751b017cf5c40fce0d90fd9e8274716de327ec4ffb0df957160be3bff"}, + {file = "fonttools-4.54.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:357cacb988a18aace66e5e55fe1247f2ee706e01debc4b1a20d77400354cddeb"}, + {file = "fonttools-4.54.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8e953cc0bddc2beaf3a3c3b5dd9ab7554677da72dfaf46951e193c9653e515a"}, + {file = "fonttools-4.54.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:58d29b9a294573d8319f16f2f79e42428ba9b6480442fa1836e4eb89c4d9d61c"}, + {file = "fonttools-4.54.1-cp313-cp313-win32.whl", hash = "sha256:9ef1b167e22709b46bf8168368b7b5d3efeaaa746c6d39661c1b4405b6352e58"}, + {file = "fonttools-4.54.1-cp313-cp313-win_amd64.whl", hash = "sha256:262705b1663f18c04250bd1242b0515d3bbae177bee7752be67c979b7d47f43d"}, + {file = "fonttools-4.54.1-py3-none-any.whl", hash = "sha256:37cddd62d83dc4f72f7c3f3c2bcf2697e89a30efb152079896544a93907733bd"}, + {file = "fonttools-4.54.1.tar.gz", hash = "sha256:957f669d4922f92c171ba01bef7f29410668db09f6c02111e22b2bce446f3285"}, +] + +[[package]] +name = "h5netcdf" +version = "1.4.0" +requires_python = ">=3.9" +summary = "netCDF4 via h5py" +groups = ["default"] +dependencies = [ + "h5py", + "packaging", +] +files = [ + {file = "h5netcdf-1.4.0-py3-none-any.whl", hash = "sha256:d1bb96fce5dcf42908903c9798beeef70ac84e97159eb381f1b151459313f228"}, + {file = "h5netcdf-1.4.0.tar.gz", hash = "sha256:e959c3b5bd3ca7965ce5f4383a4e038ffcb55034c63d791829bd33a5ac38a962"}, +] + +[[package]] +name = "h5py" +version = "3.12.1" +requires_python = ">=3.9" +summary = "Read and write HDF5 files from Python" +groups = ["default"] +dependencies = [ + "numpy>=1.19.3", +] +files = [ + {file = "h5py-3.12.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:06a903a4e4e9e3ebbc8b548959c3c2552ca2d70dac14fcfa650d9261c66939ed"}, + {file = "h5py-3.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7b3b8f3b48717e46c6a790e3128d39c61ab595ae0a7237f06dfad6a3b51d5351"}, + {file = "h5py-3.12.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:050a4f2c9126054515169c49cb900949814987f0c7ae74c341b0c9f9b5056834"}, + {file = "h5py-3.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c4b41d1019322a5afc5082864dfd6359f8935ecd37c11ac0029be78c5d112c9"}, + {file = "h5py-3.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:e4d51919110a030913201422fb07987db4338eba5ec8c5a15d6fab8e03d443fc"}, + {file = "h5py-3.12.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:513171e90ed92236fc2ca363ce7a2fc6f2827375efcbb0cc7fbdd7fe11fecafc"}, + {file = "h5py-3.12.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:59400f88343b79655a242068a9c900001a34b63e3afb040bd7cdf717e440f653"}, + {file = "h5py-3.12.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3e465aee0ec353949f0f46bf6c6f9790a2006af896cee7c178a8c3e5090aa32"}, + {file = "h5py-3.12.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba51c0c5e029bb5420a343586ff79d56e7455d496d18a30309616fdbeed1068f"}, + {file = "h5py-3.12.1-cp313-cp313-win_amd64.whl", hash = "sha256:52ab036c6c97055b85b2a242cb540ff9590bacfda0c03dd0cf0661b311f522f8"}, + {file = "h5py-3.12.1.tar.gz", hash = "sha256:326d70b53d31baa61f00b8aa5f95c2fcb9621a3ee8365d770c551a13dbbcbfdf"}, +] + [[package]] name = "iniconfig" version = "2.0.0" @@ -292,6 +447,48 @@ files = [ {file = "jaxtyping-0.2.34.tar.gz", hash = "sha256:eed9a3458ec8726c84ea5457ebde53c964f65d2c22c0ec40d0555ae3fed5bbaf"}, ] +[[package]] +name = "kiwisolver" +version = "1.4.7" +requires_python = ">=3.8" +summary = "A fast implementation of the Cassowary constraint solver" +groups = ["default", "benchmarks"] +files = [ + {file = "kiwisolver-1.4.7-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:5360cc32706dab3931f738d3079652d20982511f7c0ac5711483e6eab08efff2"}, + {file = "kiwisolver-1.4.7-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:942216596dc64ddb25adb215c3c783215b23626f8d84e8eff8d6d45c3f29f75a"}, + {file = "kiwisolver-1.4.7-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:48b571ecd8bae15702e4f22d3ff6a0f13e54d3d00cd25216d5e7f658242065ee"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ad42ba922c67c5f219097b28fae965e10045ddf145d2928bfac2eb2e17673640"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:612a10bdae23404a72941a0fc8fa2660c6ea1217c4ce0dbcab8a8f6543ea9e7f"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9e838bba3a3bac0fe06d849d29772eb1afb9745a59710762e4ba3f4cb8424483"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:22f499f6157236c19f4bbbd472fa55b063db77a16cd74d49afe28992dff8c258"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:693902d433cf585133699972b6d7c42a8b9f8f826ebcaf0132ff55200afc599e"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4e77f2126c3e0b0d055f44513ed349038ac180371ed9b52fe96a32aa071a5107"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:657a05857bda581c3656bfc3b20e353c232e9193eb167766ad2dc58b56504948"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4bfa75a048c056a411f9705856abfc872558e33c055d80af6a380e3658766038"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:34ea1de54beef1c104422d210c47c7d2a4999bdecf42c7b5718fbe59a4cac383"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:90da3b5f694b85231cf93586dad5e90e2d71b9428f9aad96952c99055582f520"}, + {file = "kiwisolver-1.4.7-cp312-cp312-win32.whl", hash = "sha256:18e0cca3e008e17fe9b164b55735a325140a5a35faad8de92dd80265cd5eb80b"}, + {file = "kiwisolver-1.4.7-cp312-cp312-win_amd64.whl", hash = "sha256:58cb20602b18f86f83a5c87d3ee1c766a79c0d452f8def86d925e6c60fbf7bfb"}, + {file = "kiwisolver-1.4.7-cp312-cp312-win_arm64.whl", hash = "sha256:f5a8b53bdc0b3961f8b6125e198617c40aeed638b387913bf1ce78afb1b0be2a"}, + {file = "kiwisolver-1.4.7-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:2e6039dcbe79a8e0f044f1c39db1986a1b8071051efba3ee4d74f5b365f5226e"}, + {file = "kiwisolver-1.4.7-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a1ecf0ac1c518487d9d23b1cd7139a6a65bc460cd101ab01f1be82ecf09794b6"}, + {file = "kiwisolver-1.4.7-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7ab9ccab2b5bd5702ab0803676a580fffa2aa178c2badc5557a84cc943fcf750"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f816dd2277f8d63d79f9c8473a79fe54047bc0467754962840782c575522224d"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf8bcc23ceb5a1b624572a1623b9f79d2c3b337c8c455405ef231933a10da379"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dea0bf229319828467d7fca8c7c189780aa9ff679c94539eed7532ebe33ed37c"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c06a4c7cf15ec739ce0e5971b26c93638730090add60e183530d70848ebdd34"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:913983ad2deb14e66d83c28b632fd35ba2b825031f2fa4ca29675e665dfecbe1"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5337ec7809bcd0f424c6b705ecf97941c46279cf5ed92311782c7c9c2026f07f"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:4c26ed10c4f6fa6ddb329a5120ba3b6db349ca192ae211e882970bfc9d91420b"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:c619b101e6de2222c1fcb0531e1b17bbffbe54294bfba43ea0d411d428618c27"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:073a36c8273647592ea332e816e75ef8da5c303236ec0167196793eb1e34657a"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:3ce6b2b0231bda412463e152fc18335ba32faf4e8c23a754ad50ffa70e4091ee"}, + {file = "kiwisolver-1.4.7-cp313-cp313-win32.whl", hash = "sha256:f4c9aee212bc89d4e13f58be11a56cc8036cabad119259d12ace14b34476fd07"}, + {file = "kiwisolver-1.4.7-cp313-cp313-win_amd64.whl", hash = "sha256:8a3ec5aa8e38fc4c8af308917ce12c536f1c88452ce554027e55b22cbbfbff76"}, + {file = "kiwisolver-1.4.7-cp313-cp313-win_arm64.whl", hash = "sha256:76c8094ac20ec259471ac53e774623eb62e6e1f56cd8690c67ce6ce4fcb05650"}, + {file = "kiwisolver-1.4.7.tar.gz", hash = "sha256:9893ff81bd7107f7b685d3017cc6583daadb4fc26e4a888350df530e41980a60"}, +] + [[package]] name = "lineax" version = "0.0.7" @@ -309,6 +506,45 @@ files = [ {file = "lineax-0.0.7.tar.gz", hash = "sha256:e43549a8d202432d4668afe54866741a0214ccb363487bacb2a980f72840ea48"}, ] +[[package]] +name = "matplotlib" +version = "3.9.2" +requires_python = ">=3.9" +summary = "Python plotting package" +groups = ["default", "benchmarks"] +dependencies = [ + "contourpy>=1.0.1", + "cycler>=0.10", + "fonttools>=4.22.0", + "importlib-resources>=3.2.0; python_version < \"3.10\"", + "kiwisolver>=1.3.1", + "numpy>=1.23", + "packaging>=20.0", + "pillow>=8", + "pyparsing>=2.3.1", + "python-dateutil>=2.7", +] +files = [ + {file = "matplotlib-3.9.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:ac43031375a65c3196bee99f6001e7fa5bdfb00ddf43379d3c0609bdca042df9"}, + {file = "matplotlib-3.9.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:be0fc24a5e4531ae4d8e858a1a548c1fe33b176bb13eff7f9d0d38ce5112a27d"}, + {file = "matplotlib-3.9.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf81de2926c2db243c9b2cbc3917619a0fc85796c6ba4e58f541df814bbf83c7"}, + {file = "matplotlib-3.9.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6ee45bc4245533111ced13f1f2cace1e7f89d1c793390392a80c139d6cf0e6c"}, + {file = "matplotlib-3.9.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:306c8dfc73239f0e72ac50e5a9cf19cc4e8e331dd0c54f5e69ca8758550f1e1e"}, + {file = "matplotlib-3.9.2-cp312-cp312-win_amd64.whl", hash = "sha256:5413401594cfaff0052f9d8b1aafc6d305b4bd7c4331dccd18f561ff7e1d3bd3"}, + {file = "matplotlib-3.9.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:18128cc08f0d3cfff10b76baa2f296fc28c4607368a8402de61bb3f2eb33c7d9"}, + {file = "matplotlib-3.9.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4876d7d40219e8ae8bb70f9263bcbe5714415acfdf781086601211335e24f8aa"}, + {file = "matplotlib-3.9.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d9f07a80deab4bb0b82858a9e9ad53d1382fd122be8cde11080f4e7dfedb38b"}, + {file = "matplotlib-3.9.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7c0410f181a531ec4e93bbc27692f2c71a15c2da16766f5ba9761e7ae518413"}, + {file = "matplotlib-3.9.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:909645cce2dc28b735674ce0931a4ac94e12f5b13f6bb0b5a5e65e7cea2c192b"}, + {file = "matplotlib-3.9.2-cp313-cp313-win_amd64.whl", hash = "sha256:f32c7410c7f246838a77d6d1eff0c0f87f3cb0e7c4247aebea71a6d5a68cab49"}, + {file = "matplotlib-3.9.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:37e51dd1c2db16ede9cfd7b5cabdfc818b2c6397c83f8b10e0e797501c963a03"}, + {file = "matplotlib-3.9.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b82c5045cebcecd8496a4d694d43f9cc84aeeb49fe2133e036b207abe73f4d30"}, + {file = "matplotlib-3.9.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f053c40f94bc51bc03832a41b4f153d83f2062d88c72b5e79997072594e97e51"}, + {file = "matplotlib-3.9.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dbe196377a8248972f5cede786d4c5508ed5f5ca4a1e09b44bda889958b33f8c"}, + {file = "matplotlib-3.9.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:5816b1e1fe8c192cbc013f8f3e3368ac56fbecf02fb41b8f8559303f24c5015e"}, + {file = "matplotlib-3.9.2.tar.gz", hash = "sha256:96ab43906269ca64a6366934106fa01534454a69e471b7bf3d79083981aaab92"}, +] + [[package]] name = "ml-dtypes" version = "0.5.0" @@ -339,7 +575,7 @@ name = "numpy" version = "2.1.2" requires_python = ">=3.10" summary = "Fundamental package for array computing in Python" -groups = ["default", "dev"] +groups = ["default", "benchmarks", "dev"] files = [ {file = "numpy-2.1.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d7bf0a4f9f15b32b5ba53147369e94296f5fffb783db5aacc1be15b4bf72f43b"}, {file = "numpy-2.1.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b1d0fcae4f0949f215d4632be684a539859b295e2d0cb14f78ec231915d644db"}, @@ -425,12 +661,90 @@ name = "packaging" version = "24.1" requires_python = ">=3.8" summary = "Core utilities for Python packages" -groups = ["dev"] +groups = ["default", "benchmarks", "dev"] files = [ {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, ] +[[package]] +name = "pandas" +version = "2.2.3" +requires_python = ">=3.9" +summary = "Powerful data structures for data analysis, time series, and statistics" +groups = ["default", "benchmarks"] +dependencies = [ + "numpy>=1.22.4; python_version < \"3.11\"", + "numpy>=1.23.2; python_version == \"3.11\"", + "numpy>=1.26.0; python_version >= \"3.12\"", + "python-dateutil>=2.8.2", + "pytz>=2020.1", + "tzdata>=2022.7", +] +files = [ + {file = "pandas-2.2.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b1d432e8d08679a40e2a6d8b2f9770a5c21793a6f9f47fdd52c5ce1948a5a8a9"}, + {file = "pandas-2.2.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a5a1595fe639f5988ba6a8e5bc9649af3baf26df3998a0abe56c02609392e0a4"}, + {file = "pandas-2.2.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5de54125a92bb4d1c051c0659e6fcb75256bf799a732a87184e5ea503965bce3"}, + {file = "pandas-2.2.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fffb8ae78d8af97f849404f21411c95062db1496aeb3e56f146f0355c9989319"}, + {file = "pandas-2.2.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6dfcb5ee8d4d50c06a51c2fffa6cff6272098ad6540aed1a76d15fb9318194d8"}, + {file = "pandas-2.2.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:062309c1b9ea12a50e8ce661145c6aab431b1e99530d3cd60640e255778bd43a"}, + {file = "pandas-2.2.3-cp312-cp312-win_amd64.whl", hash = "sha256:59ef3764d0fe818125a5097d2ae867ca3fa64df032331b7e0917cf5d7bf66b13"}, + {file = "pandas-2.2.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f00d1345d84d8c86a63e476bb4955e46458b304b9575dcf71102b5c705320015"}, + {file = "pandas-2.2.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3508d914817e153ad359d7e069d752cdd736a247c322d932eb89e6bc84217f28"}, + {file = "pandas-2.2.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:22a9d949bfc9a502d320aa04e5d02feab689d61da4e7764b62c30b991c42c5f0"}, + {file = "pandas-2.2.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3a255b2c19987fbbe62a9dfd6cff7ff2aa9ccab3fc75218fd4b7530f01efa24"}, + {file = "pandas-2.2.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:800250ecdadb6d9c78eae4990da62743b857b470883fa27f652db8bdde7f6659"}, + {file = "pandas-2.2.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6374c452ff3ec675a8f46fd9ab25c4ad0ba590b71cf0656f8b6daa5202bca3fb"}, + {file = "pandas-2.2.3-cp313-cp313-win_amd64.whl", hash = "sha256:61c5ad4043f791b61dd4752191d9f07f0ae412515d59ba8f005832a532f8736d"}, + {file = "pandas-2.2.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:3b71f27954685ee685317063bf13c7709a7ba74fc996b84fc6821c59b0f06468"}, + {file = "pandas-2.2.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:38cf8125c40dae9d5acc10fa66af8ea6fdf760b2714ee482ca691fc66e6fcb18"}, + {file = "pandas-2.2.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ba96630bc17c875161df3818780af30e43be9b166ce51c9a18c1feae342906c2"}, + {file = "pandas-2.2.3-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1db71525a1538b30142094edb9adc10be3f3e176748cd7acc2240c2f2e5aa3a4"}, + {file = "pandas-2.2.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:15c0e1e02e93116177d29ff83e8b1619c93ddc9c49083f237d4312337a61165d"}, + {file = "pandas-2.2.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:ad5b65698ab28ed8d7f18790a0dc58005c7629f227be9ecc1072aa74c0c1d43a"}, + {file = "pandas-2.2.3.tar.gz", hash = "sha256:4f18ba62b61d7e192368b84517265a99b4d7ee8912f8708660fb4a366cc82667"}, +] + +[[package]] +name = "pillow" +version = "11.0.0" +requires_python = ">=3.9" +summary = "Python Imaging Library (Fork)" +groups = ["default", "benchmarks"] +files = [ + {file = "pillow-11.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d2c0a187a92a1cb5ef2c8ed5412dd8d4334272617f532d4ad4de31e0495bd923"}, + {file = "pillow-11.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:084a07ef0821cfe4858fe86652fffac8e187b6ae677e9906e192aafcc1b69903"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8069c5179902dcdce0be9bfc8235347fdbac249d23bd90514b7a47a72d9fecf4"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f02541ef64077f22bf4924f225c0fd1248c168f86e4b7abdedd87d6ebaceab0f"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:fcb4621042ac4b7865c179bb972ed0da0218a076dc1820ffc48b1d74c1e37fe9"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:00177a63030d612148e659b55ba99527803288cea7c75fb05766ab7981a8c1b7"}, + {file = "pillow-11.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8853a3bf12afddfdf15f57c4b02d7ded92c7a75a5d7331d19f4f9572a89c17e6"}, + {file = "pillow-11.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3107c66e43bda25359d5ef446f59c497de2b5ed4c7fdba0894f8d6cf3822dafc"}, + {file = "pillow-11.0.0-cp312-cp312-win32.whl", hash = "sha256:86510e3f5eca0ab87429dd77fafc04693195eec7fd6a137c389c3eeb4cfb77c6"}, + {file = "pillow-11.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:8ec4a89295cd6cd4d1058a5e6aec6bf51e0eaaf9714774e1bfac7cfc9051db47"}, + {file = "pillow-11.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:27a7860107500d813fcd203b4ea19b04babe79448268403172782754870dac25"}, + {file = "pillow-11.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:bcd1fb5bb7b07f64c15618c89efcc2cfa3e95f0e3bcdbaf4642509de1942a699"}, + {file = "pillow-11.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0e038b0745997c7dcaae350d35859c9715c71e92ffb7e0f4a8e8a16732150f38"}, + {file = "pillow-11.0.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ae08bd8ffc41aebf578c2af2f9d8749d91f448b3bfd41d7d9ff573d74f2a6b2"}, + {file = "pillow-11.0.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d69bfd8ec3219ae71bcde1f942b728903cad25fafe3100ba2258b973bd2bc1b2"}, + {file = "pillow-11.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:61b887f9ddba63ddf62fd02a3ba7add935d053b6dd7d58998c630e6dbade8527"}, + {file = "pillow-11.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:c6a660307ca9d4867caa8d9ca2c2658ab685de83792d1876274991adec7b93fa"}, + {file = "pillow-11.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:73e3a0200cdda995c7e43dd47436c1548f87a30bb27fb871f352a22ab8dcf45f"}, + {file = "pillow-11.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:fba162b8872d30fea8c52b258a542c5dfd7b235fb5cb352240c8d63b414013eb"}, + {file = "pillow-11.0.0-cp313-cp313-win32.whl", hash = "sha256:f1b82c27e89fffc6da125d5eb0ca6e68017faf5efc078128cfaa42cf5cb38798"}, + {file = "pillow-11.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:8ba470552b48e5835f1d23ecb936bb7f71d206f9dfeee64245f30c3270b994de"}, + {file = "pillow-11.0.0-cp313-cp313-win_arm64.whl", hash = "sha256:846e193e103b41e984ac921b335df59195356ce3f71dcfd155aa79c603873b84"}, + {file = "pillow-11.0.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:4ad70c4214f67d7466bea6a08061eba35c01b1b89eaa098040a35272a8efb22b"}, + {file = "pillow-11.0.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:6ec0d5af64f2e3d64a165f490d96368bb5dea8b8f9ad04487f9ab60dc4bb6003"}, + {file = "pillow-11.0.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c809a70e43c7977c4a42aefd62f0131823ebf7dd73556fa5d5950f5b354087e2"}, + {file = "pillow-11.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:4b60c9520f7207aaf2e1d94de026682fc227806c6e1f55bba7606d1c94dd623a"}, + {file = "pillow-11.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:1e2688958a840c822279fda0086fec1fdab2f95bf2b717b66871c4ad9859d7e8"}, + {file = "pillow-11.0.0-cp313-cp313t-win32.whl", hash = "sha256:607bbe123c74e272e381a8d1957083a9463401f7bd01287f50521ecb05a313f8"}, + {file = "pillow-11.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:5c39ed17edea3bc69c743a8dd3e9853b7509625c2462532e62baa0732163a904"}, + {file = "pillow-11.0.0-cp313-cp313t-win_arm64.whl", hash = "sha256:75acbbeb05b86bc53cbe7b7e6fe00fbcf82ad7c684b3ad82e3d711da9ba287d3"}, + {file = "pillow-11.0.0.tar.gz", hash = "sha256:72bacbaf24ac003fea9bff9837d1eedb6088758d41e100c1552930151f677739"}, +] + [[package]] name = "pluggy" version = "1.5.0" @@ -442,6 +756,32 @@ files = [ {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, ] +[[package]] +name = "polars" +version = "1.12.0" +requires_python = ">=3.9" +summary = "Blazingly fast DataFrame library" +groups = ["benchmarks"] +files = [ + {file = "polars-1.12.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:8f3c4e4e423c373dda07b4c8a7ff12aa02094b524767d0ca306b1eba67f2d99e"}, + {file = "polars-1.12.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:aa6f9862f0cec6353243920d9b8d858c21ec8f25f91af203dea6ff91980e140d"}, + {file = "polars-1.12.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afb03647b5160737d2119532ee8ffe825de1d19d87f81bbbb005131786f7d59b"}, + {file = "polars-1.12.0-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:ea96aba5eb3dab8f0e6abf05ab3fc2136b329261860ef8661d20f5456a2d78e0"}, + {file = "polars-1.12.0-cp39-abi3-win_amd64.whl", hash = "sha256:a228a4b320a36d03a9ec9dfe7241b6d80a2f119b2dceb1da953166655e4cf43c"}, + {file = "polars-1.12.0.tar.gz", hash = "sha256:fb5c92de1a8f7d0a3f923fe48ea89eb518bdf55315ae917012350fa072bd64f4"}, +] + +[[package]] +name = "pyparsing" +version = "3.2.0" +requires_python = ">=3.9" +summary = "pyparsing module - Classes and methods to define and execute parsing grammars" +groups = ["default", "benchmarks"] +files = [ + {file = "pyparsing-3.2.0-py3-none-any.whl", hash = "sha256:93d9577b88da0bbea8cc8334ee8b918ed014968fd2ec383e868fb8afb1ccef84"}, + {file = "pyparsing-3.2.0.tar.gz", hash = "sha256:cbf74e27246d595d9a74b186b810f6fbb86726dbf3b9532efb343f6d7294fe9c"}, +] + [[package]] name = "pytest" version = "8.3.3" @@ -476,6 +816,30 @@ files = [ {file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"}, ] +[[package]] +name = "python-dateutil" +version = "2.9.0.post0" +requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +summary = "Extensions to the standard Python datetime module" +groups = ["default", "benchmarks"] +dependencies = [ + "six>=1.5", +] +files = [ + {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, + {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, +] + +[[package]] +name = "pytz" +version = "2024.2" +summary = "World timezone definitions, modern and historical" +groups = ["default", "benchmarks"] +files = [ + {file = "pytz-2024.2-py2.py3-none-any.whl", hash = "sha256:31c7c1817eb7fae7ca4b8c7ee50c72f93aa2dd863de768e1ef4245d426aa0725"}, + {file = "pytz-2024.2.tar.gz", hash = "sha256:2aa355083c50a0f93fa581709deac0c9ad65cca8a9e9beac660adcbd493c798a"}, +] + [[package]] name = "scipy" version = "1.14.1" @@ -511,12 +875,36 @@ version = "75.2.0" requires_python = ">=3.8" summary = "Easily download, build, install, upgrade, and uninstall Python packages" groups = ["default", "dev"] -marker = "python_version >= \"3.12\"" files = [ {file = "setuptools-75.2.0-py3-none-any.whl", hash = "sha256:a7fcb66f68b4d9e8e66b42f9876150a3371558f98fa32222ffaa5bced76406f8"}, {file = "setuptools-75.2.0.tar.gz", hash = "sha256:753bb6ebf1f465a1912e19ed1d41f403a79173a9acf66a42e7e6aec45c3c16ec"}, ] +[[package]] +name = "six" +version = "1.16.0" +requires_python = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +summary = "Python 2 and 3 compatibility utilities" +groups = ["default", "benchmarks"] +files = [ + {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, + {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, +] + +[[package]] +name = "stanio" +version = "0.5.1" +requires_python = ">=3.8" +summary = "Utilities for preparing Stan inputs and processing Stan outputs" +groups = ["benchmarks"] +dependencies = [ + "numpy", +] +files = [ + {file = "stanio-0.5.1-py3-none-any.whl", hash = "sha256:99ad590daa5834681245c2b651716ec2e06223853661ada21430c621521c849f"}, + {file = "stanio-0.5.1.tar.gz", hash = "sha256:348d52f947dec431e118f4b601c4c5296929b86401d4d4dd5aa9373b0d4ae4ac"}, +] + [[package]] name = "toolz" version = "1.0.0" @@ -528,6 +916,20 @@ files = [ {file = "toolz-1.0.0.tar.gz", hash = "sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02"}, ] +[[package]] +name = "tqdm" +version = "4.66.6" +requires_python = ">=3.7" +summary = "Fast, Extensible Progress Meter" +groups = ["benchmarks"] +dependencies = [ + "colorama; platform_system == \"Windows\"", +] +files = [ + {file = "tqdm-4.66.6-py3-none-any.whl", hash = "sha256:223e8b5359c2efc4b30555531f09e9f2f3589bcd7fdd389271191031b49b7a63"}, + {file = "tqdm-4.66.6.tar.gz", hash = "sha256:4bdd694238bef1485ce839d67967ab50af8f9272aab687c0d7702a01da0be090"}, +] + [[package]] name = "typeguard" version = "2.13.3" @@ -549,3 +951,46 @@ files = [ {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] + +[[package]] +name = "tzdata" +version = "2024.2" +requires_python = ">=2" +summary = "Provider of IANA time zone data" +groups = ["default", "benchmarks"] +files = [ + {file = "tzdata-2024.2-py2.py3-none-any.whl", hash = "sha256:a48093786cdcde33cad18c2555e8532f34422074448fbc874186f0abd79565cd"}, + {file = "tzdata-2024.2.tar.gz", hash = "sha256:7d85cc416e9382e69095b7bdf4afd9e3880418a2413feec7069d533d6b4e31cc"}, +] + +[[package]] +name = "xarray" +version = "2024.10.0" +requires_python = ">=3.10" +summary = "N-D labeled arrays and datasets in Python" +groups = ["default"] +dependencies = [ + "numpy>=1.24", + "packaging>=23.1", + "pandas>=2.1", +] +files = [ + {file = "xarray-2024.10.0-py3-none-any.whl", hash = "sha256:ae1d38cb44a0324dfb61e492394158ae22389bf7de9f3c174309c17376df63a0"}, + {file = "xarray-2024.10.0.tar.gz", hash = "sha256:e369e2bac430e418c2448e5b96f07da4635f98c1319aa23cfeb3fbcb9a01d2e0"}, +] + +[[package]] +name = "xarray-einstats" +version = "0.8.0" +requires_python = ">=3.10" +summary = "Stats, linear algebra and einops for xarray" +groups = ["default"] +dependencies = [ + "numpy>=1.23", + "scipy>=1.9", + "xarray>=2022.09.0", +] +files = [ + {file = "xarray_einstats-0.8.0-py3-none-any.whl", hash = "sha256:fd00552c3fb5c859b1ebc7c88a97342d3bb93d14bba904c5a9b94a4f724b76b4"}, + {file = "xarray_einstats-0.8.0.tar.gz", hash = "sha256:7f1573f9bd4d60d6e7ed9fd27c4db39da51ec49bf8ba654d4602a139a6309d7f"}, +] diff --git a/pyproject.toml b/pyproject.toml index cb73612..e8cb237 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ authors = [ dependencies = [ "blackjax>=1.2.4", "jax==0.4.33", + "arviz>=0.20.0", ] requires-python = ">=3.12" readme = "README.md" @@ -17,6 +18,11 @@ license = {text = "MIT"} demos = [ "optimistix>=0.0.8", ] +benchmarks = [ + "cmdstanpy>=1.2.4", + "polars>=1.12.0", + "matplotlib>=3.9.2", +] [build-system] requires = ["hatchling"] build-backend = "hatchling.build" diff --git a/src/grapevine/__init__.py b/src/grapevine/__init__.py index 2294c65..be37405 100644 --- a/src/grapevine/__init__.py +++ b/src/grapevine/__init__.py @@ -1,5 +1,10 @@ from grapevine.grapenuts import grapenuts_sampler from grapevine.integrator import grapevine_velocity_verlet -from grapevine.util import run_grapenuts +from grapevine.util import run_grapenuts, get_idata -__all__ = ["grapenuts_sampler", "grapevine_velocity_verlet", "run_grapenuts"] +__all__ = [ + "grapenuts_sampler", + "grapevine_velocity_verlet", + "run_grapenuts", + "get_idata", +] diff --git a/src/grapevine/util.py b/src/grapevine/util.py index 479119f..8aae120 100644 --- a/src/grapevine/util.py +++ b/src/grapevine/util.py @@ -2,12 +2,14 @@ from typing import Callable, TypedDict, Unpack +import arviz as az import equinox as eqx import jax from blackjax.types import ArrayTree from blackjax.util import run_inference_algorithm from jax._src.random import KeyArray +from jax import numpy as jnp from grapevine import grapenuts_sampler, grapevine_velocity_verlet from grapevine.adaptation import grapenuts_window_adaptation @@ -62,3 +64,24 @@ def run_grapenuts( progress_bar=progress_bar, ) return states, info + + +def get_idata(samples, info, coords=None, dims=None) -> az.InferenceData: + """Get an arviz InferenceData from a grapeNUTS output.""" + sample_dict = {k: jnp.expand_dims(v, 0) for k, v in samples.position.items()} + posterior = az.convert_to_inference_data( + sample_dict, + group="posterior", + coords=coords, + dims=dims, + ) + sample_stats = az.convert_to_inference_data( + { + "diverging": info.is_divergent, + "energy": info.energy, + }, + group="sample_stats", + ) + idata = az.concat(posterior, sample_stats) + assert idata is not None, "idata should not be None!" + return idata