Skip to content

Commit 328864b

Browse files
larryshamalamabrandonwillard
authored andcommitted
Add ChiSquareRV JAX implementation
1 parent 18f65fd commit 328864b

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

aesara/link/jax/dispatch/random.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,6 @@ def jax_sample_fn_wald(op):
369369
def sample_fn(rng, size, dtype, *parameters):
370370
rng_key = rng["jax_state"]
371371
rng_key, sampling_key = jax.random.split(rng_key, 2)
372-
373372
mean, scale = parameters
374373

375374
key1, key2 = jax.random.split(sampling_key, 2)
@@ -390,6 +389,21 @@ def sample_fn(rng, size, dtype, *parameters):
390389
return sample_fn
391390

392391

392+
@jax_sample_fn.register(aer.ChiSquareRV)
393+
def jax_sample_fn_chisquare(op):
394+
"""JAX implementation of `ChiSquareRV`"""
395+
396+
def sample_fn(rng, size, dtype, *parameters):
397+
rng_key = rng["jax_state"]
398+
rng_key, sampling_key = jax.random.split(rng_key, 2)
399+
(df,) = parameters
400+
sample = jax.random.gamma(sampling_key, df / 2, size, dtype) * 2
401+
rng["jax_state"] = rng_key
402+
return (rng, sample)
403+
404+
return sample_fn
405+
406+
393407
@jax_sample_fn.register(aer.GeometricRV)
394408
def jax_sample_fn_geometric(op):
395409
"""JAX implementation of `GeometricRV`."""

tests/link/jax/test_random.py

+13
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,19 @@ def test_random_updates(rng_ctor):
9797
lambda *args: args,
9898
None,
9999
),
100+
(
101+
aer.chisquare,
102+
[
103+
set_test_value(
104+
at.dvector(),
105+
np.array([1.0, 2.0], dtype=np.float64),
106+
)
107+
],
108+
(2,),
109+
"chi2",
110+
lambda *args: args,
111+
50_000,
112+
),
100113
(
101114
aer.exponential,
102115
[

0 commit comments

Comments
 (0)