fix(gh-2036): MyPy Errors in `numpyro.distributions.constraints` Module
This PR resolves the mypy errors in the `numpyro.distributions.constraints` module that were reported after https://github.com/pyro-ppl/numpyro/pull/2032.
I have tried to replicate the solution proposed by @fehiepsi in #2066, i.e., the use of generics (see commits 69c1ed5eeeb9ddc0841f723d7089cc51a0cd5a16 through 1d6b24df2a340efa6118910043a821a7dbdafcfb).
I have slightly modified the logic. Some notes on the changes:

- `ArrayLike` contains `complex`, and there is no partial order over complex numbers, so MyPy was throwing errors for the `>`, `<`, `<=`, and `>=` operators. They have been replaced with the equivalent `jax.numpy` functions.
- From MyPy's point of view, there is no mod operation between arrays and integers; it has been replaced with `jax.numpy.mod`.
- Bitwise operations have been replaced with `jax.numpy.logical_and` and `jax.numpy.logical_or`.
- The type of the `other` argument in the `__eq__` method has been changed to `object`, because the method can take any type of object; its implementation only determines whether the two objects are the same. I have added an `if not isinstance(other, ...): return False` statement in some places to resolve MyPy's errors.
I ran into the following problems that, in my understanding, require some discussion:

- `event_dim` and `is_discrete` are read-only properties that some constraints modify; therefore, I made them private attributes and introduced getter and setter methods for each.
- The `__eq__` method is expected to return a `bool`, but we return arrays of booleans. I have marked these cases to be ignored by MyPy.
- We cannot duck-type the constraint objects with `ConstraintT` at the end of the module, because `ConstraintT` expects a `NumLike` object, but some constraints support only `NonScalarArray`. This issue is solvable via a generic typing protocol (see the changes in c12b51439b44b9b7fb4a7f0eaebac3f9c1149876). The same issue appears with `TransformT` and subclasses of `Transform`: for the statement `transform_obj: TransformT = TransformClass(...)`, MyPy throws an error if `TransformClass` uses anything other than `NumLike`. This issue is also addressable via a generic typing protocol.
- `jax>=0.7.2` introduced `TypedNdArray` to represent constants in jaxpr (ref https://github.com/jax-ml/jax/issues/31989, https://github.com/jax-ml/jax/pull/32227). It is also part of the `ArrayLike` type and has no `reshape` method.
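A minimal sketch of the private-attribute-plus-property pattern for `event_dim` and `is_discrete`. `ConstraintSketch` and `IndependentSketch` are hypothetical classes, and a wrapper deriving its `event_dim` from a base constraint is an illustrative assumption about why some constraints need to modify these attributes.

```python
class ConstraintSketch:
    """Hypothetical base class; attribute names mirror the description above."""

    def __init__(self, event_dim: int = 0, is_discrete: bool = False) -> None:
        # Stored privately so the public names can be exposed as properties.
        self._event_dim = event_dim
        self._is_discrete = is_discrete

    @property
    def event_dim(self) -> int:
        return self._event_dim

    @event_dim.setter
    def event_dim(self, value: int) -> None:
        self._event_dim = value

    @property
    def is_discrete(self) -> bool:
        return self._is_discrete

    @is_discrete.setter
    def is_discrete(self, value: bool) -> None:
        self._is_discrete = value


class IndependentSketch(ConstraintSketch):
    """Hypothetical wrapper that reinterprets batch dims as event dims."""

    def __init__(self, base: ConstraintSketch, reinterpreted_batch_ndims: int) -> None:
        super().__init__()
        self.base = base
        # The subclass adjusts the inherited attributes through the setters.
        self.event_dim = base.event_dim + reinterpreted_batch_ndims
        self.is_discrete = base.is_discrete
```

Going through explicit setters keeps the attributes typed in one place while still letting subclasses change them.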
That covers the major points of this PR. I will update the description if I recall anything else.
@fehiepsi, can you look into these changes?
@juanitorduz, you helped us a lot in solving #2066; you might be interested in these changes too.
This is a tricky one, but there is great progress :) I created a pull request to your branch @Qazalbash with a potential solution: https://github.com/Qazalbash/numpyro/pull/3. MyPy seems happy about it, but please see if it makes sense to you.
The only errors left here come from statements like `constrain_obj: ConstraintT = ConstraintClass(...)` when `ConstraintClass` uses the `NonScalarArray` type, because `ConstraintT` is defined with `NumLike` and expects the same from `ConstraintClass`. The same problem appears with `TransformT`.
I tried to address this issue by making the typing protocols generic, but later reverted the change; it is available at c12b51439b44b9b7fb4a7f0eaebac3f9c1149876.
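The generic-protocol idea can be sketched as follows. This is an illustration under assumptions, not the code in c12b51439b44b9b7fb4a7f0eaebac3f9c1149876: `ConstraintProtocol`, `PositiveSketch`, and `CorrMatrixSketch` are hypothetical names, the `NumLike` and `NonScalarArray` aliases below are simplified stand-ins for numpyro's, and NumPy arrays stand in for JAX arrays. Because the argument type is a contravariant type parameter, a constraint that accepts only non-scalar arrays can be annotated as `ConstraintProtocol[NonScalarArray]` instead of being forced to match a protocol fixed to `NumLike`.

```python
from typing import Protocol, TypeVar, Union

import numpy as np

# Hypothetical, simplified stand-ins for numpyro's NumLike and NonScalarArray.
NumLike = Union[int, float, np.ndarray]
NonScalarArray = np.ndarray

# Contravariant because the type only appears as a method argument.
ArgT = TypeVar("ArgT", contravariant=True)


class ConstraintProtocol(Protocol[ArgT]):
    """Generic protocol: the accepted input type is a parameter, not fixed to NumLike."""

    @property
    def event_dim(self) -> int: ...

    def __call__(self, x: ArgT) -> np.ndarray: ...


class PositiveSketch:
    event_dim = 0

    def __call__(self, x: NumLike) -> np.ndarray:
        return np.asarray(np.greater(x, 0))


class CorrMatrixSketch:
    event_dim = 2

    def __call__(self, x: NonScalarArray) -> np.ndarray:
        # Only meaningful for arrays: checks symmetry and a unit diagonal.
        symmetric = np.allclose(x, np.swapaxes(x, -1, -2))
        unit_diag = np.allclose(np.diagonal(x, axis1=-2, axis2=-1), 1.0)
        return np.asarray(np.logical_and(symmetric, unit_diag))


# Each constraint is checked against the protocol parameterized by its own input type.
positive: ConstraintProtocol[NumLike] = PositiveSketch()
corr: ConstraintProtocol[NonScalarArray] = CorrMatrixSketch()
```

A non-generic protocol whose `__call__` takes `NumLike` would reject `CorrMatrixSketch`, since a method that accepts only arrays cannot stand in for one that must also accept scalars.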
Hi @fehiepsi, I am a little busy with my grad school applications; can I update you after Dec 15?
Absolutely, please take your time!