Skip to content

Conversation

@Qazalbash
Copy link
Contributor

@Qazalbash Qazalbash commented Oct 21, 2025

This PR contains the resolution of mypy errors passed by #2032, in the numpyro.distributions.constraints module.

I have tried to replicate the same solution proposed by @fehiepsi in #2066, i.e., the use of generics (see from 69c1ed5 till 1d6b24d).

I have slightly modified the logic. Some notes on them,

  1. ArrayLike contains complex and there is no partial order over complex numbers. MyPy was throwing errors for >, <, <=, and >= operators. They have been replaced with the equivalent jax.numpy function.
  2. There is no mod operation between arrays and integers; it has been replaced with jax.numpy.mod.
  3. Bitwise operations have been replaced with jax.numpy.logical_and and jax.numpy.logical_or.
  4. The type of the argument in the __eq__ method has been changed to object because the method can take any type of object; its implementation is to classify if the objects are the same or not. I have added if not isinstance(other, ...): return False statement at some places due to MyPy's errors.

I tackled with following problems that, in my understanding, require some discussion,

  1. event_dim and is_discrete are read-only properties that are modified by some constraints; therefore, I made them private attributes and introduced getter and setter methods for each.
  2. __eq__ method expects return type to be a bool, but we are returning arrays of booleans. I have marked them to be ignored by MyPy.
  3. We can not ducktype the constraint object with ConstraintT at the end of the module, because the ConstraintT object expects a NumLike object, but some constraints support only NonScalarArray. This issue is solvable via a generic typing protocol (see changes of c12b514). It can also be seen with TransformT and subclasses of Transform, for statement transform_obj: TransformT = TransformClass(...), MyPy will throw an error, if TransformClass uses anything other than NumLike. This issue is also addressable via a generic typing protocol.
  4. jax>=0.7.2 has introduced TypedNdArray to represent constants in jaxpr (ref Include Typed<type> in ArrayLike jax-ml/jax#31989, Add Typed... types to ArrayLike. jax-ml/jax#32227). It is also a part of ArrayLike type, and has no reshape method.

These are all the major outlines of this PR. I will update the description if I recall any.

@Qazalbash
Copy link
Contributor Author

@fehiepsi, can you look into these changes?


@juanitorduz you helped us a lot in solving #2066, you might be interested in these changes too.

@juanitorduz
Copy link
Collaborator

This is a tricky one but there is great progress :) I created a pull request to your branch @Qazalbash with a potential solution Qazalbash#3 . MyPy seems happy about it, but please see if make sense for you

@Qazalbash
Copy link
Contributor Author

Only errors left here are coming from the statments constrain_obj: ConstraintT = ConstraintClass(...), when ConstraintClass uses NonScalarArray type. Because ConstraintT has NumLike and expects similar from ConstraintClass. Same problem can be seen in TransformT.

I tried to address this issue by making typing protocols generic. I later reverted it. They are available at c12b514.

def __call__(self, x: NonScalarArray) -> ArrayLike:
return ordered_vector.check(x) & independent(positive, 1).check(x)
return jnp.logical_and(
ordered_vector.check(x), independent[NumLike](positive, 1).check(x)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think isdependent[NumLike](...) will not work for Python < 3.12, am I right?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants