diff --git a/CHANGELOG.md b/CHANGELOG.md index c23e11a4b8..6c67cfa887 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `ModeData.dispersion` and `ModeSolverData.dispersion` are calculated together with the group index. - String matching feature `contains_str` to `assert_log_level` testing utility. +### Changed +- `jax` and `jaxlib` versions bumped to `0.4.*`. + ## [2.5.0] - 2023-12-13 ### Added diff --git a/requirements/jax.txt b/requirements/jax.txt index 00b87da262..cebdb68aa5 100644 --- a/requirements/jax.txt +++ b/requirements/jax.txt @@ -3,14 +3,14 @@ -r core.txt # regular case (linux, macos) -jaxlib>=0.3.14,<=0.4.14; platform_system != "Windows" -jax[cpu]>=0.3.14,<=0.4.14; platform_system != "Windows" +jaxlib>=0.3.14,==0.4.*; platform_system != "Windows" +jax[cpu]>=0.3.14,==0.4.*; platform_system != "Windows" # we downgrade to 0.3 for windows users because the only binaries for windows are 0.3 currently. jaxlib==0.3.14; platform_system == "Windows" and python_version < "3.9" jax[cpu]==0.3.14; platform_system == "Windows" and python_version < "3.9" -# windows users running python > 3.9 can install same jax version as unix +# windows users running python > 3.9 must use older version for now jaxlib>=0.3.14,<=0.4.14; platform_system == "Windows" and python_version >= "3.9" jax[cpu]>=0.3.14,<=0.4.14; platform_system == "Windows" and python_version >= "3.9"