Skip to content

Commit 18f65fd

Browse files
Smit-createbrandonwillard
authored andcommitted
Add standard_exp and standard_t
1 parent 587bc9f commit 18f65fd

File tree

3 files changed

+50
-0
lines changed

3 files changed

+50
-0
lines changed

aesara/tensor/random/basic.py

+6
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,9 @@ def __call__(self, scale=1.0, size=None, **kwargs):
778778

779779

780780
exponential = ExponentialRV()
781+
standard_exponential = get_partial_wrapper(
782+
exponential, "standard_exponential", scale=1.0
783+
)
781784

782785

783786
class WeibullRV(RandomVariable):
@@ -1525,6 +1528,7 @@ def rng_fn_scipy(cls, rng, df, loc, scale, size):
15251528

15261529

15271530
t = StudentTRV()
1531+
standard_t = get_partial_wrapper(t, "standard_t", loc=0.0, scale=1.0)
15281532

15291533

15301534
class BernoulliRV(ScipyRandomVariable):
@@ -2229,8 +2233,10 @@ def __call__(self, x, **kwargs):
22292233
"triangular",
22302234
"uniform",
22312235
"standard_cauchy",
2236+
"standard_exponential",
22322237
"standard_gamma",
22332238
"standard_normal",
2239+
"standard_t",
22342240
"negative_binomial",
22352241
"gengamma",
22362242
"t",

tests/tensor/random/test_basic.py

+42
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,10 @@
5151
random,
5252
rayleigh,
5353
standard_cauchy,
54+
standard_exponential,
5455
standard_gamma,
5556
standard_normal,
57+
standard_t,
5658
t,
5759
triangular,
5860
truncexpon,
@@ -449,6 +451,18 @@ def test_exponential_default_args():
449451
compare_sample_values(exponential)
450452

451453

454+
@pytest.mark.parametrize(
455+
"size",
456+
[
457+
(10_000,),
458+
None,
459+
(100, 100),
460+
],
461+
)
462+
def test_std_exponential_args(size):
463+
compare_sample_values(standard_exponential, size=size)
464+
465+
452466
def test_rayleigh_default_args():
453467
compare_sample_values(rayleigh)
454468

@@ -1031,6 +1045,34 @@ def test_t_samples(df, loc, scale, size):
10311045
)
10321046

10331047

1048+
@pytest.mark.parametrize(
1049+
"df, size",
1050+
[
1051+
(
1052+
np.array(2, dtype=config.floatX),
1053+
None,
1054+
),
1055+
(
1056+
np.array(2.30, dtype=config.floatX),
1057+
[2, 3],
1058+
),
1059+
(
1060+
np.full((1, 2), 5, dtype=config.floatX),
1061+
None,
1062+
),
1063+
],
1064+
)
1065+
def test_standard_t_samples(df, size):
1066+
compare_sample_values(
1067+
standard_t,
1068+
df,
1069+
size=size,
1070+
test_fn=lambda df, size=None, random_state=None, **kwargs: t.rng_fn(
1071+
random_state, df, 0.0, 1.0, size
1072+
),
1073+
)
1074+
1075+
10341076
@pytest.mark.parametrize(
10351077
"p, size",
10361078
[

tests/tensor/random/test_utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@ def test_basics(self, rng_ctor):
111111
assert hasattr(random, "standard_normal")
112112
assert hasattr(random, "standard_cauchy")
113113
assert hasattr(random, "standard_gamma")
114+
assert hasattr(random, "standard_exponential")
115+
assert hasattr(random, "standard_t")
114116

115117
with pytest.raises(AttributeError):
116118
np_random = RandomStream(namespace=np, rng_ctor=rng_ctor)

0 commit comments

Comments
 (0)