From 85784b8b838ebf70c4be40127f4b9695cfed8636 Mon Sep 17 00:00:00 2001 From: Thomas Aarholt Date: Thu, 29 Aug 2024 11:48:02 +0200 Subject: [PATCH 1/8] Type `get_tau_sigma` correctly and fix comment --- pymc/distributions/continuous.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 0af7193b32..b18cae58d2 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -55,10 +55,10 @@ ) from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.utils import normalize_size_param -from pytensor.tensor.variable import TensorConstant +from pytensor.tensor.variable import TensorConstant, TensorVariable from pymc.logprob.abstract import _logprob_helper -from pymc.logprob.basic import icdf +from pymc.logprob.basic import TensorLike, icdf from pymc.pytensorf import normalize_rng_param try: @@ -214,7 +214,9 @@ def assert_negative_support(var, label, distname, value=-1e-6): return Assert(msg)(var, pt.all(pt.ge(var, 0.0))) -def get_tau_sigma(tau=None, sigma=None): +def get_tau_sigma( + tau: TensorLike | None = None, sigma: TensorLike | None = None +) -> tuple[TensorVariable, TensorVariable]: r""" Find precision and standard deviation. The link between the two parameterizations is given by the inverse relationship: @@ -241,13 +243,14 @@ def get_tau_sigma(tau=None, sigma=None): sigma = pt.as_tensor_variable(1.0) tau = pt.as_tensor_variable(1.0) elif tau is None: + assert sigma is not None # Just for type checker sigma = pt.as_tensor_variable(sigma) # Keep tau negative, if sigma was negative, so that it will # fail when used tau = (sigma**-2.0) * pt.sign(sigma) else: tau = pt.as_tensor_variable(tau) - # Keep tau negative, if sigma was negative, so that it will + # Keep sigma negative, if tau was negative, so that it will # fail when used sigma = pt.abs(tau) ** -0.5 * pt.sign(tau) From 8ea8b6cc1f7adc796618812ced04fe76f0243f90 Mon Sep 17 00:00:00 2001 From: Thomas Aarholt Date: Thu, 29 Aug 2024 12:03:31 +0200 Subject: [PATCH 2/8] Fix tuple type Use of bound_arg_indices shows that this should be a (start, end) tuple, not a list. --- pymc/distributions/continuous.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index b18cae58d2..30abbbfc8e 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -152,7 +152,7 @@ class BoundedContinuous(Continuous): """Base class for bounded continuous distributions""" # Indices of the arguments that define the lower and upper bounds of the distribution - bound_args_indices: list[int] | None = None + bound_args_indices: tuple[int, int] | None = None @_default_transform.register(PositiveContinuous) From b2692ea36c5210e77c915de1c150b02989e31dbb Mon Sep 17 00:00:00 2001 From: Thomas Aarholt Date: Thu, 29 Aug 2024 12:04:33 +0200 Subject: [PATCH 3/8] Fix incorrect type hinting syntax I don't want to touch the `= None` part. I could add ` | None` to the type hint but I suspect the `= None` is unnecessary. --- pymc/distributions/distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index df6f5efa08..d84e8a643a 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -423,7 +423,7 @@ def change_symbolic_rv_size(op: SymbolicRandomVariable, rv, new_size, expand) -> class Distribution(metaclass=DistributionMeta): """Statistical distribution""" - rv_op: [RandomVariable, SymbolicRandomVariable] = None + rv_op: RandomVariable | SymbolicRandomVariable = None rv_type: MetaType = None def __new__( From af91def837e765a65bb17cf62372286f05912450 Mon Sep 17 00:00:00 2001 From: Thomas Aarholt Date: Mon, 2 Sep 2024 09:16:00 +0200 Subject: [PATCH 4/8] Type rv_op based on behaviour rather than type This is a bit like using a Protocol class, we say that whatever is passed to rv_op, it should return a TensorVariable. --- pymc/distributions/distribution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index d84e8a643a..d427f828c5 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -423,8 +423,8 @@ def change_symbolic_rv_size(op: SymbolicRandomVariable, rv, new_size, expand) -> class Distribution(metaclass=DistributionMeta): """Statistical distribution""" - rv_op: RandomVariable | SymbolicRandomVariable = None - rv_type: MetaType = None + rv_op: Callable[..., TensorVariable] | None = None + rv_type: MetaType | None = None def __new__( cls, From 2e84b30130f858f86b84ea32e1e6e5096f02f3ee Mon Sep 17 00:00:00 2001 From: Thomas Aarholt Date: Mon, 2 Sep 2024 13:20:21 +0200 Subject: [PATCH 5/8] Try removing Optional / = None completely I don't think these are necessary at all. The tests will tell us so. --- pymc/distributions/distribution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index d427f828c5..1813492428 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -423,8 +423,8 @@ def change_symbolic_rv_size(op: SymbolicRandomVariable, rv, new_size, expand) -> class Distribution(metaclass=DistributionMeta): """Statistical distribution""" - rv_op: Callable[..., TensorVariable] | None = None - rv_type: MetaType | None = None + rv_op: Callable[..., TensorVariable] + rv_type: MetaType def __new__( cls, From 157c29a60edbe670e245b1208e3e97749669da8d Mon Sep 17 00:00:00 2001 From: Thomas Aarholt Date: Thu, 5 Sep 2024 07:48:55 +0200 Subject: [PATCH 6/8] Try setting rv_op: Any = None --- pymc/distributions/distribution.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 1813492428..ca209f5f7f 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -21,7 +21,7 @@ from abc import ABCMeta from collections.abc import Callable, Sequence from functools import singledispatch -from typing import TypeAlias +from typing import Any, TypeAlias import numpy as np @@ -122,6 +122,8 @@ def __new__(cls, name, bases, clsdict): ) class_change_dist_size = clsdict.get("change_dist_size") + if class_change_dist_size is not None: + raise ValueError("HAHAHA") if class_change_dist_size: @_change_dist_size.register(rv_type) @@ -423,8 +425,12 @@ def change_symbolic_rv_size(op: SymbolicRandomVariable, rv, new_size, expand) -> class Distribution(metaclass=DistributionMeta): """Statistical distribution""" - rv_op: Callable[..., TensorVariable] - rv_type: MetaType + # rv_op and _type are set to None via the DistributionMeta.__new__ + # if not specified as class attributes in subclasses of Distribution. + # rv_op can either be a class (see the Normal class) or a method + # (see the Censored class), both callable to return a TensorVariable. + rv_op: Any = None + rv_type: MetaType | None = None def __new__( cls, From 7bd35a80b5ac4fe411825d0bf5aaa22741de9e0f Mon Sep 17 00:00:00 2001 From: Thomas Aarholt Date: Tue, 8 Oct 2024 09:24:58 +0200 Subject: [PATCH 7/8] Remove debug statement --- pymc/distributions/distribution.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index ca209f5f7f..ef08bc889e 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -122,8 +122,6 @@ def __new__(cls, name, bases, clsdict): ) class_change_dist_size = clsdict.get("change_dist_size") - if class_change_dist_size is not None: - raise ValueError("HAHAHA") if class_change_dist_size: @_change_dist_size.register(rv_type) From 1575243d30952d384009dd1e23f0bb4f2c6b0f4d Mon Sep 17 00:00:00 2001 From: Thomas Aarholt Date: Tue, 8 Oct 2024 13:01:09 +0200 Subject: [PATCH 8/8] Correct type definition for bound_arg_indices Allows optional lower or upper bounds. --- pymc/distributions/continuous.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 30abbbfc8e..58bca3829f 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -152,7 +152,7 @@ class BoundedContinuous(Continuous): """Base class for bounded continuous distributions""" # Indices of the arguments that define the lower and upper bounds of the distribution - bound_args_indices: tuple[int, int] | None = None + bound_args_indices: tuple[int | None, int | None] | None = None @_default_transform.register(PositiveContinuous)