
fix(gh-2036): MyPy Errors in `numpyro.distributions.constraints` Module

Open Qazalbash opened this issue 5 months ago • 5 comments

This PR resolves the MyPy errors reported by https://github.com/pyro-ppl/numpyro/pull/2032 in the `numpyro.distributions.constraints` module.

I have tried to replicate the solution proposed by @fehiepsi in #2066, i.e., the use of generics (see the commits from 69c1ed5eeeb9ddc0841f723d7089cc51a0cd5a16 through 1d6b24df2a340efa6118910043a821a7dbdafcfb).

I have slightly modified the logic. Some notes on the changes:

  1. `ArrayLike` includes complex numbers, which have no ordering, so MyPy raised errors for the `>`, `<`, `<=`, and `>=` operators. These have been replaced with the equivalent `jax.numpy` functions.
  2. MyPy found no supported `%` (mod) operation between arrays and integers, so it has been replaced with `jax.numpy.mod`.
  3. Bitwise operators (`&`, `|`) have been replaced with `jax.numpy.logical_and` and `jax.numpy.logical_or`.
  4. The type of the argument of the `__eq__` method has been changed to `object`, because the method can receive any object; its job is to decide whether two objects are equal. To satisfy MyPy, I have added `if not isinstance(other, ...): return False` checks in some places.
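A minimal sketch of the pattern in the notes above, using a hypothetical `IntervalSketch` constraint and a hypothetical `is_even_sketch` helper (neither is numpyro's actual code): comparisons, `&`, and `%` go through `jax.numpy` functions, and `__eq__` takes `object` with an early `isinstance` check.

```python
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


class IntervalSketch:
    """Hypothetical interval constraint, for illustration only (not numpyro's class)."""

    def __init__(self, lower_bound: ArrayLike, upper_bound: ArrayLike) -> None:
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

    def __call__(self, x: ArrayLike) -> jax.Array:
        # jnp.greater_equal / jnp.less_equal replace `>=` / `<=`,
        # and jnp.logical_and replaces the bitwise `&`.
        return jnp.logical_and(
            jnp.greater_equal(x, self.lower_bound),
            jnp.less_equal(x, self.upper_bound),
        )

    def __eq__(self, other: object) -> bool:
        # `other` is typed as `object`; anything that is not the same
        # kind of constraint is simply reported as unequal.
        if not isinstance(other, IntervalSketch):
            return False
        return bool(jnp.array_equal(self.lower_bound, other.lower_bound)) and bool(
            jnp.array_equal(self.upper_bound, other.upper_bound)
        )


def is_even_sketch(x: ArrayLike) -> jax.Array:
    # Hypothetical helper: jnp.mod replaces the `%` operator.
    return jnp.equal(jnp.mod(x, 2), 0)
```

Note that the `bool(...)` conversion in this sketch's `__eq__` only works on concrete (non-traced) values.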

I ran into the following problems that, in my understanding, require some discussion:

  1. `event_dim` and `is_discrete` are read-only properties, yet some constraints modify them; therefore, I made them private attributes and introduced getter and setter methods for each.
  2. The `__eq__` method is expected to return a `bool`, but we return arrays of booleans. I have marked these lines to be ignored by MyPy.
  3. We cannot duck-type the constraint objects as `ConstraintT` at the end of the module, because `ConstraintT` expects a `NumLike` object, while some constraints support only `NonScalarArray`. This issue is solvable via a generic typing protocol (see the changes in c12b51439b44b9b7fb4a7f0eaebac3f9c1149876). The same problem appears with `TransformT` and the subclasses of `Transform`: for a statement such as `transform_obj: TransformT = TransformClass(...)`, MyPy throws an error if `TransformClass` uses anything other than `NumLike`. This issue is also addressable via a generic typing protocol.
  4. `jax>=0.7.2` introduced `TypedNdArray` to represent constants in jaxprs (ref https://github.com/jax-ml/jax/issues/31989, https://github.com/jax-ml/jax/pull/32227). It is also part of the `ArrayLike` type and has no `reshape` method.
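A minimal sketch of the private-attribute approach for `event_dim` and `is_discrete`, using hypothetical `ConstraintSketch` and `IndependentSketch` classes rather than numpyro's actual ones. The setters matter because MyPy rejects assignment to a property that only defines a getter.

```python
class ConstraintSketch:
    """Hypothetical base class, for illustration only (not numpyro's Constraint)."""

    def __init__(self, event_dim: int = 0, is_discrete: bool = False) -> None:
        # Private backing attributes for the public properties below.
        self._event_dim = event_dim
        self._is_discrete = is_discrete

    @property
    def event_dim(self) -> int:
        return self._event_dim

    @event_dim.setter
    def event_dim(self, value: int) -> None:
        self._event_dim = value

    @property
    def is_discrete(self) -> bool:
        return self._is_discrete

    @is_discrete.setter
    def is_discrete(self, value: bool) -> None:
        self._is_discrete = value


class IndependentSketch(ConstraintSketch):
    """Hypothetical wrapper that reinterprets batch dims as event dims."""

    def __init__(self, base: ConstraintSketch, reinterpreted_batch_ndims: int) -> None:
        super().__init__()
        self.base = base
        # These assignments go through the setters, which MyPy accepts.
        self.event_dim = base.event_dim + reinterpreted_batch_ndims
        self.is_discrete = base.is_discrete
```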

This is an outline of the major changes in this PR. I will update the description if I recall anything else.

Qazalbash • Oct 21 '25 21:10

@fehiepsi, can you look into these changes?


@juanitorduz, you helped us a lot in solving #2066; you might be interested in these changes too.

Qazalbash • Oct 21 '25 21:10

This is a tricky one, but there is great progress :) @Qazalbash, I created a pull request against your branch with a potential solution: https://github.com/Qazalbash/numpyro/pull/3. MyPy seems happy with it, but please check whether it makes sense to you.

juanitorduz • Oct 21 '25 21:10

The only errors left here come from statements such as `constraint_obj: ConstraintT = ConstraintClass(...)` when `ConstraintClass` uses the `NonScalarArray` type, because `ConstraintT` is defined in terms of `NumLike` and expects the same from `ConstraintClass`. The same problem can be seen with `TransformT`.

I tried to address this issue by making the typing protocols generic, but later reverted it. Those changes are available at c12b51439b44b9b7fb4a7f0eaebac3f9c1149876.
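A minimal sketch of the generic-protocol idea, with hypothetical stand-ins: `ConstraintProtocol` plays the role of `ConstraintT`, `float` plays `NumLike`, and `list[float]` plays `NonScalarArray`. This is not the code from c12b51439b44b9b7fb4a7f0eaebac3f9c1149876.

```python
from typing import Protocol, TypeVar

# Invariant type variable: it appears both as a parameter and a return type.
NumT = TypeVar("NumT")


class ConstraintProtocol(Protocol[NumT]):
    """Hypothetical generic stand-in for a ConstraintT-style protocol."""

    def __call__(self, x: NumT) -> bool: ...

    def feasible_like(self, prototype: NumT) -> NumT: ...


class PositiveScalar:
    """Hypothetical constraint accepting scalars (stand-in for NumLike)."""

    def __call__(self, x: float) -> bool:
        return x > 0

    def feasible_like(self, prototype: float) -> float:
        return 1.0


class PositiveVector:
    """Hypothetical constraint accepting only non-scalar inputs (stand-in for NonScalarArray)."""

    def __call__(self, x: list[float]) -> bool:
        return all(v > 0 for v in x)

    def feasible_like(self, prototype: list[float]) -> list[float]:
        return [1.0] * len(prototype)


# Each constraint is checked against the protocol parameterized by its own input type.
scalar_obj: ConstraintProtocol[float] = PositiveScalar()
vector_obj: ConstraintProtocol[list[float]] = PositiveVector()

# By contrast, `bad: ConstraintProtocol[float] = PositiveVector()` would be
# rejected by MyPy: a method that only accepts list[float] cannot stand in
# for one that must accept float.
```

With a single non-generic protocol fixed to the scalar-friendly type, only the first assignment would type-check, which mirrors the `NumLike` versus `NonScalarArray` mismatch described above.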

Qazalbash • Oct 27 '25 08:10

Hi @fehiepsi, I am a little busy with my grad school applications; can I update you after Dec 15?

Qazalbash • Dec 01 '25 13:12

Absolutely, please take your time!

fehiepsi • Dec 01 '25 13:12