-
Couldn't load subscription status.
- Fork 271
fix(gh-2036): MyPy Errors in numpyro.distributions.constraints Module
#2085
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
|
@fehiepsi, can you look into these changes? @juanitorduz you helped us a lot in solving #2066, you might be interested in these changes too. |
|
This is a tricky one but there is great progress :) I created a pull request to your branch @Qazalbash with a potential solution Qazalbash#3 . MyPy seems happy about it, but please see if make sense for you |
…ts in typing modules
…`Constraint` class
…tConstraint` class
…ne imports in typing modules" This reverts commit 78d8e93.
… bitwise operators
…lasses to use `NonScalarArray`
…iscrete` and `event_dim` in `Constraint` class
|
Only errors left here are coming from the statments I tried to address this issue by making typing protocols generic. I later reverted it. They are available at c12b514. |
| def __call__(self, x: NonScalarArray) -> ArrayLike: | ||
| return ordered_vector.check(x) & independent(positive, 1).check(x) | ||
| return jnp.logical_and( | ||
| ordered_vector.check(x), independent[NumLike](positive, 1).check(x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think isdependent[NumLike](...) will not work for Python < 3.12, am I right?
This PR contains the resolution of mypy errors passed by #2032, in the
numpyro.distributions.constraintsmodule.I have tried to replicate the same solution proposed by @fehiepsi in #2066, i.e., the use of generics (see from 69c1ed5 till 1d6b24d).
I have slightly modified the logic. Some notes on them,
ArrayLikecontainscomplexand there is no partial order over complex numbers. MyPy was throwing errors for>,<,<=, and>=operators. They have been replaced with the equivalentjax.numpyfunction.jax.numpy.mod.jax.numpy.logical_andandjax.numpy.logical_or.__eq__method has been changed toobjectbecause the method can take any type of object; its implementation is to classify if the objects are the same or not. I have addedif not isinstance(other, ...): return Falsestatement at some places due to MyPy's errors.I tackled with following problems that, in my understanding, require some discussion,
event_dimandis_discreteare read-only properties that are modified by some constraints; therefore, I made them private attributes and introduced getter and setter methods for each.__eq__method expects return type to be abool, but we are returning arrays of booleans. I have marked them to be ignored by MyPy.ConstraintTat the end of the module, because theConstraintTobject expects aNumLikeobject, but some constraints support onlyNonScalarArray. This issue is solvable via a generic typing protocol (see changes of c12b514). It can also be seen withTransformTand subclasses ofTransform, for statementtransform_obj: TransformT = TransformClass(...), MyPy will throw an error, ifTransformClassuses anything other thanNumLike. This issue is also addressable via a generic typing protocol.jax>=0.7.2has introducedTypedNdArrayto represent constants in jaxpr (ref Include Typed<type> in ArrayLike jax-ml/jax#31989, Add Typed... types to ArrayLike. jax-ml/jax#32227). It is also a part ofArrayLiketype, and has noreshapemethod.These are all the major outlines of this PR. I will update the description if I recall any.