Add type hints to dist() in discrete distributions
johanbog committed Sep 30, 2023
1 parent 5ed39c3 commit 2c01984
Showing 1 changed file with 72 additions and 13 deletions.
pymc/distributions/discrete.py (85 changes: 72 additions & 13 deletions)
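Note: the diff below pulls DIST_PARAMETER_TYPES in from pymc.distributions.distribution without showing its definition. It is presumably the float-valued counterpart of the new DISCRETE_DIST_PARAMETER_TYPES alias introduced here; a minimal sketch of what such an alias could look like, where the exact member types are an assumption rather than part of this commit:

    from typing import TypeAlias, Union

    import numpy as np
    import numpy.typing as npt

    from pytensor.tensor.variable import TensorVariable

    # Hypothetical reconstruction; the real alias lives in
    # pymc/distributions/distribution.py and may list additional types.
    DIST_PARAMETER_TYPES: TypeAlias = Union[npt.NDArray[np.floating], float, int, TensorVariable]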
@@ -13,7 +13,10 @@
 # limitations under the License.
 import warnings
 
+from typing import Optional, TypeAlias, Union
+
 import numpy as np
+import numpy.typing as npt
 import pytensor.tensor as pt
 
 from pytensor.tensor import TensorConstant
@@ -29,6 +32,7 @@
     nbinom,
     poisson,
 )
+from pytensor.tensor.variable import TensorVariable
 from scipy import stats
 
 import pymc as pm
@@ -45,7 +49,7 @@
     normal_lccdf,
     normal_lcdf,
 )
-from pymc.distributions.distribution import Discrete
+from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Discrete
 from pymc.distributions.shape_utils import rv_size_is_none
 from pymc.logprob.basic import logcdf, logp
 from pymc.math import sigmoid
@@ -66,6 +70,8 @@
     "OrderedProbit",
 ]
 
+DISCRETE_DIST_PARAMETER_TYPES: TypeAlias = Union[npt.NDArray[np.int_], int, TensorVariable]
+
 
 class Binomial(Discrete):
     R"""
@@ -115,7 +121,14 @@ class Binomial(Discrete):
     rv_op = binomial
 
     @classmethod
-    def dist(cls, n, p=None, logit_p=None, *args, **kwargs):
+    def dist(
+        cls,
+        n: DISCRETE_DIST_PARAMETER_TYPES,
+        p: Optional[DIST_PARAMETER_TYPES] = None,
+        logit_p: Optional[DIST_PARAMETER_TYPES] = None,
+        *args,
+        **kwargs,
+    ):
         if p is not None and logit_p is not None:
             raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
         elif p is None and logit_p is None:
@@ -231,7 +244,14 @@ def BetaBinom(a, b, n, x):
     rv_op = betabinom
 
     @classmethod
-    def dist(cls, alpha, beta, n, *args, **kwargs):
+    def dist(
+        cls,
+        alpha: DIST_PARAMETER_TYPES,
+        beta: DIST_PARAMETER_TYPES,
+        n: DISCRETE_DIST_PARAMETER_TYPES,
+        *args,
+        **kwargs,
+    ):
         alpha = pt.as_tensor_variable(floatX(alpha))
         beta = pt.as_tensor_variable(floatX(beta))
         n = pt.as_tensor_variable(intX(n))
@@ -337,7 +357,13 @@ class Bernoulli(Discrete):
     rv_op = bernoulli
 
     @classmethod
-    def dist(cls, p=None, logit_p=None, *args, **kwargs):
+    def dist(
+        cls,
+        p: Optional[DIST_PARAMETER_TYPES] = None,
+        logit_p: Optional[DIST_PARAMETER_TYPES] = None,
+        *args,
+        **kwargs,
+    ):
         if p is not None and logit_p is not None:
             raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
         elif p is None and logit_p is None:
@@ -453,7 +479,7 @@ def DiscreteWeibull(q, b, x):
     rv_op = discrete_weibull
 
     @classmethod
-    def dist(cls, q, beta, *args, **kwargs):
+    def dist(cls, q: DIST_PARAMETER_TYPES, beta: DIST_PARAMETER_TYPES, *args, **kwargs):
         q = pt.as_tensor_variable(floatX(q))
         beta = pt.as_tensor_variable(floatX(beta))
         return super().dist([q, beta], **kwargs)
@@ -542,7 +568,7 @@ class Poisson(Discrete):
     rv_op = poisson
 
     @classmethod
-    def dist(cls, mu, *args, **kwargs):
+    def dist(cls, mu: DIST_PARAMETER_TYPES, *args, **kwargs):
         mu = pt.as_tensor_variable(floatX(mu))
         return super().dist([mu], *args, **kwargs)
 
@@ -664,7 +690,15 @@ def NegBinom(a, m, x):
     rv_op = nbinom
 
     @classmethod
-    def dist(cls, mu=None, alpha=None, p=None, n=None, *args, **kwargs):
+    def dist(
+        cls,
+        mu: Optional[DIST_PARAMETER_TYPES] = None,
+        alpha: Optional[DIST_PARAMETER_TYPES] = None,
+        p: Optional[DIST_PARAMETER_TYPES] = None,
+        n: Optional[DIST_PARAMETER_TYPES] = None,
+        *args,
+        **kwargs,
+    ):
         n, p = cls.get_n_p(mu=mu, alpha=alpha, p=p, n=n)
         n = pt.as_tensor_variable(floatX(n))
         p = pt.as_tensor_variable(floatX(p))
@@ -777,7 +811,7 @@ class Geometric(Discrete):
     rv_op = geometric
 
     @classmethod
-    def dist(cls, p, *args, **kwargs):
+    def dist(cls, p: DIST_PARAMETER_TYPES, *args, **kwargs):
         p = pt.as_tensor_variable(floatX(p))
         return super().dist([p], *args, **kwargs)
 
@@ -878,7 +912,14 @@ class HyperGeometric(Discrete):
     rv_op = hypergeometric
 
     @classmethod
-    def dist(cls, N, k, n, *args, **kwargs):
+    def dist(
+        cls,
+        N: Optional[DISCRETE_DIST_PARAMETER_TYPES],
+        k: Optional[DISCRETE_DIST_PARAMETER_TYPES],
+        n: Optional[DISCRETE_DIST_PARAMETER_TYPES],
+        *args,
+        **kwargs,
+    ):
         good = pt.as_tensor_variable(intX(k))
         bad = pt.as_tensor_variable(intX(N - k))
         n = pt.as_tensor_variable(intX(n))
@@ -1015,7 +1056,13 @@ class DiscreteUniform(Discrete):
     rv_op = discrete_uniform
 
     @classmethod
-    def dist(cls, lower, upper, *args, **kwargs):
+    def dist(
+        cls,
+        lower: DISCRETE_DIST_PARAMETER_TYPES,
+        upper: DISCRETE_DIST_PARAMETER_TYPES,
+        *args,
+        **kwargs,
+    ):
         lower = intX(pt.floor(lower))
         upper = intX(pt.floor(upper))
         return super().dist([lower, upper], **kwargs)
@@ -1110,7 +1157,12 @@ class Categorical(Discrete):
     rv_op = categorical
 
     @classmethod
-    def dist(cls, p=None, logit_p=None, **kwargs):
+    def dist(
+        cls,
+        p: Optional[np.ndarray] = None,
+        logit_p: Optional[float] = None,
+        **kwargs,
+    ):
         if p is not None and logit_p is not None:
             raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
         elif p is None and logit_p is None:
@@ -1185,7 +1237,7 @@ class _OrderedLogistic(Categorical):
     rv_op = categorical
 
     @classmethod
-    def dist(cls, eta, cutpoints, *args, **kwargs):
+    def dist(cls, eta: DIST_PARAMETER_TYPES, cutpoints: DIST_PARAMETER_TYPES, *args, **kwargs):
         eta = pt.as_tensor_variable(floatX(eta))
         cutpoints = pt.as_tensor_variable(cutpoints)
 
@@ -1291,7 +1343,14 @@ class _OrderedProbit(Categorical):
     rv_op = categorical
 
     @classmethod
-    def dist(cls, eta, cutpoints, sigma=1, *args, **kwargs):
+    def dist(
+        cls,
+        eta: DIST_PARAMETER_TYPES,
+        cutpoints: DIST_PARAMETER_TYPES,
+        sigma: DIST_PARAMETER_TYPES = 1.0,
+        *args,
+        **kwargs,
+    ):
         eta = pt.as_tensor_variable(floatX(eta))
         cutpoints = pt.as_tensor_variable(cutpoints)
 
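For reference, a minimal sketch of the annotated signatures in use; the distribution calls are the standard pymc .dist() API, while the parameter values are illustrative:

    import numpy as np

    import pymc as pm

    # n is integer-like (DISCRETE_DIST_PARAMETER_TYPES); p is float-like
    # (DIST_PARAMETER_TYPES). Scalars and arrays both satisfy the aliases.
    x = pm.Binomial.dist(n=10, p=0.3)
    xs = pm.Binomial.dist(n=np.array([5, 10]), p=np.array([0.2, 0.7]))

    # p and logit_p remain mutually exclusive: supplying both raises the
    # ValueError seen in the diff above.
    y = pm.Bernoulli.dist(logit_p=0.0)  # equivalent to p=0.5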
