From e1fb10cea7abed61cea2598b66e2bfa2fd70c248 Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Tue, 10 Dec 2024 11:14:34 -0700 Subject: [PATCH 1/2] Fixed a bug in `named_arrays.random.binomial()` where the units weren't being handled properly. --- .../_scalars/scalar_named_array_functions.py | 48 ++++++++++++++++++- .../_vectors/vector_named_array_functions.py | 7 ++- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/named_arrays/_scalars/scalar_named_array_functions.py b/named_arrays/_scalars/scalar_named_array_functions.py index 17488df..1a3dc26 100644 --- a/named_arrays/_scalars/scalar_named_array_functions.py +++ b/named_arrays/_scalars/scalar_named_array_functions.py @@ -36,7 +36,6 @@ na.random.uniform, na.random.normal, na.random.poisson, - na.random.binomial, ) PLT_PLOT_LIKE_FUNCTIONS = ( na.plt.plot, @@ -579,6 +578,53 @@ def random( ) +@_implements(na.random.binomial) +def random_binomial( + n: int | u.Quantity | na.AbstractScalarArray, + p: float | na.AbstractScalarArray, + shape_random: None | dict[str, int] = None, + seed: None | int = None, +): + try: + n = scalars._normalize(n) + p = scalars._normalize(p) + except na.ScalarTypeError: + return NotImplemented + + if shape_random is None: + shape_random = dict() + + shape_base = na.shape_broadcasted(n, p) + shape = na.broadcast_shapes(shape_base, shape_random) + + n = n.ndarray_aligned(shape) + p = p.ndarray_aligned(shape) + + unit = na.unit(n) + + if unit is not None: + n = n.value + + if seed is None: + func = np.random.binomial + else: + func = np.random.default_rng(seed).binomial + + value = func( + n=n, + p=p, + size=tuple(shape.values()), + ) + + if unit is not None: + value = value << unit + + return na.ScalarArray( + ndarray=value, + axes=tuple(shape.keys()), + ) + + def plt_plot_like( func: Callable, *args: na.AbstractScalarArray, diff --git a/named_arrays/_vectors/vector_named_array_functions.py b/named_arrays/_vectors/vector_named_array_functions.py index b320a95..dd4811c 100644 --- a/named_arrays/_vectors/vector_named_array_functions.py +++ b/named_arrays/_vectors/vector_named_array_functions.py @@ -22,7 +22,12 @@ OutputT = TypeVar("OutputT", bound="float | u.Quantity | na.AbstractVectorArray") ASARRAY_LIKE_FUNCTIONS = named_arrays._scalars.scalar_named_array_functions.ASARRAY_LIKE_FUNCTIONS -RANDOM_FUNCTIONS = named_arrays._scalars.scalar_named_array_functions.RANDOM_FUNCTIONS +RANDOM_FUNCTIONS = ( + na.random.uniform, + na.random.normal, + na.random.poisson, + na.random.binomial, +) PLT_PLOT_LIKE_FUNCTIONS = named_arrays._scalars.scalar_named_array_functions.PLT_PLOT_LIKE_FUNCTIONS NDFILTER_FUNCTIONS = named_arrays._scalars.scalar_named_array_functions.NDFILTER_FUNCTIONS HANDLED_FUNCTIONS = dict() From b898bd0cd176c50a438ba8e6ffc4a62aad465e2a Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Tue, 10 Dec 2024 12:04:02 -0700 Subject: [PATCH 2/2] coverage --- .../uncertainties_named_array_functions.py | 7 ++- named_arrays/tests/test_random.py | 57 +++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) create mode 100644 named_arrays/tests/test_random.py diff --git a/named_arrays/_scalars/uncertainties/uncertainties_named_array_functions.py b/named_arrays/_scalars/uncertainties/uncertainties_named_array_functions.py index 2e8ab55..bb8f5e2 100644 --- a/named_arrays/_scalars/uncertainties/uncertainties_named_array_functions.py +++ b/named_arrays/_scalars/uncertainties/uncertainties_named_array_functions.py @@ -19,7 +19,12 @@ ] ASARRAY_LIKE_FUNCTIONS = named_arrays._scalars.scalar_named_array_functions.ASARRAY_LIKE_FUNCTIONS -RANDOM_FUNCTIONS = named_arrays._scalars.scalar_named_array_functions.RANDOM_FUNCTIONS +RANDOM_FUNCTIONS = ( + na.random.uniform, + na.random.normal, + na.random.poisson, + na.random.binomial, +) PLT_PLOT_LIKE_FUNCTIONS = named_arrays._scalars.scalar_named_array_functions.PLT_PLOT_LIKE_FUNCTIONS NDFILTER_FUNCTIONS = named_arrays._scalars.scalar_named_array_functions.NDFILTER_FUNCTIONS HANDLED_FUNCTIONS = dict() diff --git a/named_arrays/tests/test_random.py b/named_arrays/tests/test_random.py new file mode 100644 index 0000000..7fbd8fb --- /dev/null +++ b/named_arrays/tests/test_random.py @@ -0,0 +1,57 @@ +import pytest +import numpy as np +import astropy.units as u +import named_arrays as na + + +@pytest.mark.parametrize( + argnames="n", + argvalues=[ + 10, + (11 * u.photon).astype(int), + na.ScalarArray(12), + (na.arange(1, 10, axis="x") << u.photon).astype(int), + na.Cartesian2dVectorArray(10, 11), + ], +) +@pytest.mark.parametrize( + argnames="p", + argvalues=[ + 0.5, + na.ScalarArray(0.51), + na.linspace(0.4, 0.5, axis="p", num=5), + na.UniformUncertainScalarArray(0.5, width=0.1), + na.Cartesian2dVectorArray(0.5, 0.6), + ], +) +@pytest.mark.parametrize( + argnames="shape_random", + argvalues=[ + None, + dict(_s=6), + ], +) +@pytest.mark.parametrize( + argnames="seed", + argvalues=[ + None, + 42, + ], +) +def test_binomial( + n: int | u.Quantity | na.AbstractScalar | na.AbstractVectorArray, + p: float | na.AbstractScalar | na.AbstractVectorArray, + shape_random: None | dict[str, int], + seed: None | int, +): + result = na.random.binomial( + n=n, + p=p, + shape_random=shape_random, + seed=seed, + ) + + assert na.unit(result) == na.unit(n) + + assert np.all(result >= 0) + assert np.all(result <= n)