This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Implement Scan's as_while in JAX #710

@sokol11

Description

I get the traceback below when running `sampling_jax.numpyro_nuts(...)`. I installed Aesara and PyMC from source on their development branches. This is a Colab machine running Python 3.7. Any help would be greatly appreciated!

UnfilteredStackTrace: TypeError: __init__() missing 1 required positional argument: 'as_while'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------


The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)

/usr/local/lib/python3.7/dist-packages/aesara/link/jax/dispatch.py in scan(*outer_inputs)
    418     def scan(*outer_inputs):
    419         scan_args = ScanArgs(
--> 420             list(outer_inputs), [None] * op.n_outs, op.inputs, op.outputs, op.info
    421         )
    422 

TypeError: __init__() missing 1 required positional argument: 'as_while'
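The title refers to Scan's `as_while` mode, which presumably marks a scan with a termination condition: it runs for at most a fixed number of steps but may stop early once a condition holds. Because JAX needs static shapes, one common way to express such a loop is `jax.lax.while_loop` writing into a preallocated output buffer, returning the number of steps actually taken. The sketch below is only an illustration of that lowering, not Aesara's JAX dispatcher or its `ScanArgs` API; the `while_scan` function and the `step_fn(carry) -> (new_carry, done)` interface are hypothetical, and the stopping step's output is kept in the buffer by assumption.

```python
import jax
import jax.numpy as jnp


def while_scan(step_fn, init, n_steps):
    """Hypothetical helper: apply `step_fn` up to `n_steps` times, stopping early.

    `step_fn(carry) -> (new_carry, done)` is an assumed interface for this sketch.
    `n_steps` must be a Python int so the output buffer has a static shape.
    Returns (outputs buffer, number of steps taken, final carry).
    """
    init = jnp.asarray(init)
    # Static-shape buffer; slots after the stopping step stay at zero.
    outputs = jnp.zeros((n_steps,) + init.shape, dtype=init.dtype)

    def cond(state):
        i, _, _, done = state
        # Continue while there is room left and the step function has not signalled done.
        return jnp.logical_and(i < n_steps, jnp.logical_not(done))

    def body(state):
        i, carry, outs, _ = state
        new_carry, done = step_fn(carry)
        # Record this step's output, including the step that triggers the stop.
        outs = outs.at[i].set(new_carry)
        return i + 1, new_carry, outs, jnp.asarray(done, dtype=bool)

    # Explicit dtypes keep the loop-carry types identical across iterations.
    state = (jnp.int32(0), init, outputs, jnp.array(False))
    n_taken, final, outputs, _ = jax.lax.while_loop(cond, body, state)
    return outputs, n_taken, final


def double_until_large(x):
    # Example step: double the carry and stop once it exceeds 10.
    new_x = x * 2.0
    return new_x, new_x > 10.0


outs, n, final = while_scan(double_until_large, jnp.float32(1.0), n_steps=8)
# outs == [2, 4, 8, 16, 0, 0, 0, 0], n == 4, final == 16.0
```

Returning `n_taken` alongside the padded buffer lets a caller slice off the unused tail outside of traced code, which is one way to reconcile an early-stopping loop with JAX's fixed-shape requirement.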

Metadata

Assignees

No one assigned

    Labels

    JAX (Involves JAX transpilation), Scan (Involves the `Scan` `Op`), bug (Something isn't working), help wanted (Extra attention is needed), important

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
