-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add way to set backend fn
random generators
#7629
base: main
Are you sure you want to change the base?
Conversation
Need to raise an error if the linker is JAXLinker as |
Perhaps a more explicit name like |
Co-authored-by: Ricardo Vieira <ricardo.vieira1994@gmail.com>
I don't like errorring out here. Maybe a warning will be good for now, until we figure out how to control jax's random keys. |
@ricardoV94, I can't get the function used by the trace to compile with JAX or NUMBA mode. Could you have a look at the test and tell me if I'm messing up the config context? |
This PR depends on #7540. Do not merge before that one has been merged.
Description
This PR adds the
set_function_rngs
function that takes a compiled pytensor function, looks for any random generators in it, and makes a copy that sets all generators to spawned versions of a supplied generator. The implementation is taken from @ricardoV94's function here (and that's why I listed him as coauthor of this commit). This function is then used ininit_traces
and base trace initialization to make the backend'sfn
have a properly seeded random generator. Thanks to this change, we can reproducible sampling results even when the function hasDeterministic
that depend on raw pytensor random variables.Related Issue
Checklist
Type of change