About the JAX version requirement
I noticed that every version of the dm-pix library requires jax>=0.4.16, even version 0.1.0. Is that necessary?
I wonder whether this could be relaxed so that I can use dm-pix with jax==0.3.14.
Hi @AlbertHuyb, I think we expect a version >= 0.2.17 (although this might be outdated; I need to check it properly), and if you pip install dm-pix you get JAX transitively from chex, see here. I understand this is not ideal, but I don't think anything prevents you from using the particular version you require. You can always give it a go and see what happens 😄 Also, you will probably gain much more by installing JAX yourself with the optimizations that best suit your needs, as described here (that's also why we don't enforce the installation through pip, but get it transitively from chex). Hope this helps! 🙏
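If you do "give it a go", it can help to first confirm which JAX release is actually installed and how it compares with the floors mentioned in this thread (0.2.17 from the reply, 0.4.16 from the question). The sketch below is a hypothetical helper, not part of dm-pix or chex; `parse_release`, `meets_minimum` and `installed_version` are illustrative names. It only uses `importlib.metadata` from the standard library and assumes plain numeric release strings.

```python
from importlib.metadata import PackageNotFoundError, version
from itertools import takewhile
from typing import Optional, Tuple


def parse_release(v: str) -> Tuple[int, ...]:
    """Turn a release string like "0.3.14" into (0, 3, 14).

    Simplifying assumption: parsing stops at the first non-numeric part,
    so suffixes such as "rc1" or "dev2023..." are not interpreted.
    """
    return tuple(int(part) for part in takewhile(str.isdigit, v.split(".")))


def meets_minimum(installed: str, minimum: str) -> bool:
    # Tuples compare element by element, so (0, 3, 14) < (0, 4, 16).
    return parse_release(installed) >= parse_release(minimum)


def installed_version(dist: str) -> Optional[str]:
    # Return None instead of raising when the distribution is not installed.
    try:
        return version(dist)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    jax_version = installed_version("jax")
    print("installed jax:", jax_version)
    if jax_version is not None:
        # Floors taken from this thread; the real requirement may differ.
        for floor in ("0.2.17", "0.4.16"):
            print(f">= {floor}?", meets_minimum(jax_version, floor))
```

For real-world version strings with pre-release or dev suffixes, `packaging.version.Version` is the more robust way to compare versions.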