Test problems with JAX #768

@matteobachetti

@dhuppenkothen @Gaurav17Joshi

  1. We are getting this new deprecation warning. It is better to fix it ASAP, so that we build on a stable API:

    DeprecationWarning: jax.linear_util.transformation is deprecated. Use jax.extend.linear_util.transformation instead.

  2. A new test problem:

    ERROR ../../.tox/py311-test-alldeps-cov/lib/python3.11/site-packages/stingray/modeling/tests/test_gpmodeling.py::TestGPResult::test_sample - jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float64[5] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
    JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
    The function being traced when the value leaked was sample_U at /home/runner/work/stingray/stingray/.tox/py311-test-alldeps-cov/lib/python3.11/site-packages/jaxns/model.py:47 traced for jit.
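For item 1, one way to follow the warning's advice while still supporting JAX releases that predate `jax.extend` is to try the new module path first and fall back to the old one. This is a minimal sketch: `import_first` is a hypothetical helper, and it assumes the code that triggers the warning imports `linear_util` by module name (the issue does not say whether that code lives in stingray or in one of its dependencies).

```python
import importlib
from types import ModuleType


def import_first(*names: str) -> ModuleType:
    """Return the first module in `names` that imports successfully.

    Hypothetical helper: tries each dotted module path in order and
    raises ImportError listing every failure if none can be imported.
    """
    errors = []
    for name in names:
        try:
            return importlib.import_module(name)
        except ImportError as exc:  # ModuleNotFoundError is a subclass
            errors.append(f"{name}: {exc}")
    raise ImportError("no candidate module could be imported: " + "; ".join(errors))


# Hypothetical usage: prefer the new location suggested by the warning,
# falling back to the deprecated one only on older JAX versions.
# lu = import_first("jax.extend.linear_util", "jax.linear_util")
# lu.transformation(...)
```

The fallback only touches the deprecated path when the new one is missing, so recent JAX versions should no longer emit the warning from this import.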
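For item 2, the error text says JAX forbids saving traced intermediates to global state. Below is a minimal sketch of that failure mode, following the pattern described in the JAX documentation for `UnexpectedTracerError`. It does not reproduce jaxns's `sample_U`, where the traceback says the leak happens; `bad_sample`, `good_sample`, and `leak_raises` are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp

# Global state: appending a traced value here is the side effect JAX forbids.
_leaked = []


@jax.jit
def bad_sample(u):
    x = jnp.sin(u)      # inside jit, x is a tracer, not a concrete array
    _leaked.append(x)   # the tracer escapes the transformation via global state
    return 2.0 * x


@jax.jit
def good_sample(u):
    x = jnp.sin(u)
    # Return the intermediate explicitly instead of stashing it globally.
    return 2.0 * x, x


def leak_raises():
    """Return True if using the leaked tracer raises UnexpectedTracerError."""
    bad_sample(jnp.ones(5))  # first call traces the function and leaks x
    try:
        _leaked[0] + 1.0     # touching the escaped tracer outside jit
    except jax.errors.UnexpectedTracerError:
        return True
    return False


out, intermediate = good_sample(jnp.zeros(5))
print(out.shape, intermediate.shape)
```

Since the leaked value is created inside jaxns rather than in the stingray test itself, the offending side effect may not be in stingray's code. Enabling `jax.config.update("jax_check_tracer_leaks", True)` before running the failing test may help pinpoint where the tracer escapes.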
