diff --git a/README.md b/README.md index 24547cc..94458fa 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,7 @@ # grapevine +[![Tests](https://github.com/dtu-qmcm/grapevine/actions/workflows/run_tests.yml/badge.svg)](https://github.com/dtu-qmcm/grapevine/actions/workflows/run_tests.yml) +[![Project Status: WIP – Initial development is in progress, but there has not yet been a stable, usable release suitable for the public.](https://www.repostatus.org/badges/latest/wip.svg)](https://www.repostatus.org/#wip) +[![Supported Python versions: 3.12 and newer](https://img.shields.io/badge/python->=3.12-blue.svg)](https://www.python.org/) JAX/Blackjax implementation of the grapevine method for reusing the solutions of guessing problems embedded in Hamiltonian trajectories. diff --git a/pdm.lock b/pdm.lock index 81c2aeb..9042e4c 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:8ec91c6840ebaff75127bb41d8140975411d6f7891c0dca13e1560b35605a668" +content_hash = "sha256:a150c284cfae7bdedd9707da184b565598f659bc084a0613bc01e94f7a886db6" [[metadata.targets]] requires_python = ">=3.12" @@ -157,6 +157,22 @@ files = [ {file = "coverage-7.6.4.tar.gz", hash = "sha256:29fc0f17b1d3fea332f8001d4558f8214af7f1d87a345f3a133c901d60347c73"}, ] +[[package]] +name = "equinox" +version = "0.11.8" +requires_python = ">=3.9" +summary = "Elegant easy-to-use neural networks in JAX." +groups = ["dev"] +dependencies = [ + "jax!=0.4.27,>=0.4.13", + "jaxtyping>=0.2.20", + "typing-extensions>=4.5.0", +] +files = [ + {file = "equinox-0.11.8-py3-none-any.whl", hash = "sha256:552292b473956693e8e8973bdae9b58aaec54fd48e192921beb82995e3a9c995"}, + {file = "equinox-0.11.8.tar.gz", hash = "sha256:d1e91a05e41bb9538db72a8e15d26daf958348c26714533434c88c5ec0c0b0ef"}, +] + [[package]] name = "etils" version = "1.10.0" @@ -262,6 +278,37 @@ files = [ {file = "jaxopt-0.8.3.tar.gz", hash = "sha256:4b06dfa6f915a4f3291699606245af6069371a48dc5c92d4c507840d62990646"}, ] +[[package]] +name = "jaxtyping" +version = "0.2.34" +requires_python = "~=3.9" +summary = "Type annotations and runtime checking for shape and dtype of JAX arrays, and PyTrees." +groups = ["dev"] +dependencies = [ + "typeguard==2.13.3", +] +files = [ + {file = "jaxtyping-0.2.34-py3-none-any.whl", hash = "sha256:2f81fb6d1586e497a6ea2d28c06dcab37b108a096cbb36ea3fe4fa2e1c1f32e5"}, + {file = "jaxtyping-0.2.34.tar.gz", hash = "sha256:eed9a3458ec8726c84ea5457ebde53c964f65d2c22c0ec40d0555ae3fed5bbaf"}, +] + +[[package]] +name = "lineax" +version = "0.0.7" +requires_python = "~=3.9" +summary = "Linear solvers in JAX and Equinox." +groups = ["dev"] +dependencies = [ + "equinox>=0.11.5", + "jax>=0.4.26", + "jaxtyping>=0.2.20", + "typing-extensions>=4.5.0", +] +files = [ + {file = "lineax-0.0.7-py3-none-any.whl", hash = "sha256:c261977fd2104010ff34b7353deef22961da3ca46f341f158567dc2bbb8c2372"}, + {file = "lineax-0.0.7.tar.gz", hash = "sha256:e43549a8d202432d4668afe54866741a0214ccb363487bacb2a980f72840ea48"}, +] + [[package]] name = "ml-dtypes" version = "0.5.0" @@ -355,6 +402,24 @@ files = [ {file = "optax-0.2.3.tar.gz", hash = "sha256:ec7ab925440b0c5a512e1f24fba0fb3e7d760a7fd5d2496d7a691e9d37da01d9"}, ] +[[package]] +name = "optimistix" +version = "0.0.9" +requires_python = "~=3.9" +summary = "Nonlinear optimisation in JAX and Equinox." +groups = ["dev"] +dependencies = [ + "equinox>=0.11.7", + "jax>=0.4.28", + "jaxtyping>=0.2.23", + "lineax>=0.0.6", + "typing-extensions>=4.5.0", +] +files = [ + {file = "optimistix-0.0.9-py3-none-any.whl", hash = "sha256:d47d47ef520a4c4e37ff050c89793401cb7142009ef917a986bdc55f9d5db2aa"}, + {file = "optimistix-0.0.9.tar.gz", hash = "sha256:b5e3ce9e6d111f399269c7565c68125bdf6bae9b8ff6cacf30232231159354ef"}, +] + [[package]] name = "packaging" version = "24.1" @@ -463,6 +528,17 @@ files = [ {file = "toolz-1.0.0.tar.gz", hash = "sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02"}, ] +[[package]] +name = "typeguard" +version = "2.13.3" +requires_python = ">=3.5.3" +summary = "Run-time type checker for Python" +groups = ["dev"] +files = [ + {file = "typeguard-2.13.3-py3-none-any.whl", hash = "sha256:5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1"}, + {file = "typeguard-2.13.3.tar.gz", hash = "sha256:00edaa8da3a133674796cf5ea87d9f4b4c367d77476e185e80251cc13dfbb8c4"}, +] + [[package]] name = "typing-extensions" version = "4.12.2" diff --git a/pyproject.toml b/pyproject.toml index 4edf830..cb73612 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dev = [ "pytest>=8.3.3", "pytest-cov>=5.0.0", "chex>=0.1.87", + "optimistix>=0.0.9", ] [tool.hatch.metadata] allow-direct-references = true