From 2c019844ad9dfdcb7f7a32298b51e6d5409b2091 Mon Sep 17 00:00:00 2001
From: Johannes Bogen
Date: Sat, 30 Sep 2023 13:47:28 -0700
Subject: [PATCH] Add type hints to dist() in discrete distributions

---
 pymc/distributions/discrete.py | 85 ++++++++++++++++++++++++++++------
 1 file changed, 72 insertions(+), 13 deletions(-)

diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py
index d4c1e4585a1..c3a7d726133 100644
--- a/pymc/distributions/discrete.py
+++ b/pymc/distributions/discrete.py
@@ -13,7 +13,10 @@
 # limitations under the License.
 import warnings
 
+from typing import Optional, TypeAlias, Union
+
 import numpy as np
+import numpy.typing as npt
 import pytensor.tensor as pt
 
 from pytensor.tensor import TensorConstant
@@ -29,6 +32,7 @@
     nbinom,
     poisson,
 )
+from pytensor.tensor.variable import TensorVariable
 from scipy import stats
 
 import pymc as pm
@@ -45,7 +49,7 @@
     normal_lccdf,
     normal_lcdf,
 )
-from pymc.distributions.distribution import Discrete
+from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Discrete
 from pymc.distributions.shape_utils import rv_size_is_none
 from pymc.logprob.basic import logcdf, logp
 from pymc.math import sigmoid
@@ -66,6 +70,8 @@
     "OrderedProbit",
 ]
 
+DISCRETE_DIST_PARAMETER_TYPES: TypeAlias = Union[npt.NDArray[np.int_], int, TensorVariable]
+
 
 class Binomial(Discrete):
     R"""
@@ -115,7 +121,14 @@ class Binomial(Discrete):
     rv_op = binomial
 
     @classmethod
-    def dist(cls, n, p=None, logit_p=None, *args, **kwargs):
+    def dist(
+        cls,
+        n: DISCRETE_DIST_PARAMETER_TYPES,
+        p: Optional[DIST_PARAMETER_TYPES] = None,
+        logit_p: Optional[DIST_PARAMETER_TYPES] = None,
+        *args,
+        **kwargs,
+    ):
         if p is not None and logit_p is not None:
             raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
         elif p is None and logit_p is None:
@@ -231,7 +244,14 @@ def BetaBinom(a, b, n, x):
     rv_op = betabinom
 
     @classmethod
-    def dist(cls, alpha, beta, n, *args, **kwargs):
+    def dist(
+        cls,
+        alpha: DIST_PARAMETER_TYPES,
+        beta: DIST_PARAMETER_TYPES,
+        n: DISCRETE_DIST_PARAMETER_TYPES,
+        *args,
+        **kwargs,
+    ):
         alpha = pt.as_tensor_variable(floatX(alpha))
         beta = pt.as_tensor_variable(floatX(beta))
         n = pt.as_tensor_variable(intX(n))
@@ -337,7 +357,13 @@ class Bernoulli(Discrete):
     rv_op = bernoulli
 
     @classmethod
-    def dist(cls, p=None, logit_p=None, *args, **kwargs):
+    def dist(
+        cls,
+        p: Optional[DIST_PARAMETER_TYPES] = None,
+        logit_p: Optional[DIST_PARAMETER_TYPES] = None,
+        *args,
+        **kwargs,
+    ):
         if p is not None and logit_p is not None:
             raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
         elif p is None and logit_p is None:
@@ -453,7 +479,7 @@ def DiscreteWeibull(q, b, x):
     rv_op = discrete_weibull
 
     @classmethod
-    def dist(cls, q, beta, *args, **kwargs):
+    def dist(cls, q: DIST_PARAMETER_TYPES, beta: DIST_PARAMETER_TYPES, *args, **kwargs):
         q = pt.as_tensor_variable(floatX(q))
         beta = pt.as_tensor_variable(floatX(beta))
         return super().dist([q, beta], **kwargs)
@@ -542,7 +568,7 @@ class Poisson(Discrete):
     rv_op = poisson
 
     @classmethod
-    def dist(cls, mu, *args, **kwargs):
+    def dist(cls, mu: DIST_PARAMETER_TYPES, *args, **kwargs):
         mu = pt.as_tensor_variable(floatX(mu))
         return super().dist([mu], *args, **kwargs)
 
@@ -664,7 +690,15 @@ def NegBinom(a, m, x):
     rv_op = nbinom
 
     @classmethod
-    def dist(cls, mu=None, alpha=None, p=None, n=None, *args, **kwargs):
+    def dist(
+        cls,
+        mu: Optional[DIST_PARAMETER_TYPES] = None,
+        alpha: Optional[DIST_PARAMETER_TYPES] = None,
+        p: Optional[DIST_PARAMETER_TYPES] = None,
+        n: Optional[DIST_PARAMETER_TYPES] = None,
+        *args,
+        **kwargs,
+    ):
         n, p = cls.get_n_p(mu=mu, alpha=alpha, p=p, n=n)
         n = pt.as_tensor_variable(floatX(n))
         p = pt.as_tensor_variable(floatX(p))
@@ -777,7 +811,7 @@ class Geometric(Discrete):
     rv_op = geometric
 
     @classmethod
-    def dist(cls, p, *args, **kwargs):
+    def dist(cls, p: DIST_PARAMETER_TYPES, *args, **kwargs):
         p = pt.as_tensor_variable(floatX(p))
         return super().dist([p], *args, **kwargs)
 
@@ -878,7 +912,14 @@ class HyperGeometric(Discrete):
     rv_op = hypergeometric
 
     @classmethod
-    def dist(cls, N, k, n, *args, **kwargs):
+    def dist(
+        cls,
+        N: DISCRETE_DIST_PARAMETER_TYPES,
+        k: DISCRETE_DIST_PARAMETER_TYPES,
+        n: DISCRETE_DIST_PARAMETER_TYPES,
+        *args,
+        **kwargs,
+    ):
         good = pt.as_tensor_variable(intX(k))
         bad = pt.as_tensor_variable(intX(N - k))
         n = pt.as_tensor_variable(intX(n))
@@ -1015,7 +1056,13 @@ class DiscreteUniform(Discrete):
     rv_op = discrete_uniform
 
     @classmethod
-    def dist(cls, lower, upper, *args, **kwargs):
+    def dist(
+        cls,
+        lower: DISCRETE_DIST_PARAMETER_TYPES,
+        upper: DISCRETE_DIST_PARAMETER_TYPES,
+        *args,
+        **kwargs,
+    ):
         lower = intX(pt.floor(lower))
         upper = intX(pt.floor(upper))
         return super().dist([lower, upper], **kwargs)
@@ -1110,7 +1157,12 @@ class Categorical(Discrete):
     rv_op = categorical
 
     @classmethod
-    def dist(cls, p=None, logit_p=None, **kwargs):
+    def dist(
+        cls,
+        p: Optional[DIST_PARAMETER_TYPES] = None,
+        logit_p: Optional[DIST_PARAMETER_TYPES] = None,
+        **kwargs,
+    ):
         if p is not None and logit_p is not None:
             raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
         elif p is None and logit_p is None:
@@ -1185,7 +1237,7 @@ class _OrderedLogistic(Categorical):
     rv_op = categorical
 
     @classmethod
-    def dist(cls, eta, cutpoints, *args, **kwargs):
+    def dist(cls, eta: DIST_PARAMETER_TYPES, cutpoints: DIST_PARAMETER_TYPES, *args, **kwargs):
         eta = pt.as_tensor_variable(floatX(eta))
         cutpoints = pt.as_tensor_variable(cutpoints)
 
@@ -1291,7 +1343,14 @@ class _OrderedProbit(Categorical):
     rv_op = categorical
 
     @classmethod
-    def dist(cls, eta, cutpoints, sigma=1, *args, **kwargs):
+    def dist(
+        cls,
+        eta: DIST_PARAMETER_TYPES,
+        cutpoints: DIST_PARAMETER_TYPES,
+        sigma: DIST_PARAMETER_TYPES = 1.0,
+        *args,
+        **kwargs,
+    ):
         eta = pt.as_tensor_variable(floatX(eta))
         cutpoints = pt.as_tensor_variable(cutpoints)
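
Usage note (not part of the patch): a minimal sketch of how the annotated
signatures are exercised, assuming a PyMC build with this patch applied.
pm.Binomial.dist, pm.Bernoulli.dist, and pm.draw are existing PyMC API;
the parameter values below are illustrative only.

    import numpy as np
    import pymc as pm

    # `n` accepts any member of the DISCRETE_DIST_PARAMETER_TYPES union:
    # a Python int, an integer NumPy array, or a TensorVariable.
    x = pm.Binomial.dist(n=10, p=0.5)
    y = pm.Binomial.dist(n=np.array([5, 10]), p=np.array([0.3, 0.7]))

    # `p` and `logit_p` are mutually exclusive Optional parameters;
    # supplying both raises the ValueError shown in the diff.
    z = pm.Bernoulli.dist(logit_p=0.0)

    # Draw samples from the resulting random variables.
    print(pm.draw(x), pm.draw(y), pm.draw(z))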