jax.Array: support duck-typed isinstance checks
Why? As part of the type promotion discussion in JAX (#11859 and #12049), @patrick-kidger and I have been discussing what API JAX should support for both annotations and isinstance checks; we think the best approach would be to eventually make jax.Array serve all of these purposes.
This is a simple enhancement to jax.Array that allows it to be used in isinstance checks inside traced functions, similar to how isinstance(obj, jnp.ndarray) currently works.
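A minimal plain-Python sketch of the general technique, assuming hypothetical stand-ins (Tracer, AbstractArray, Array, ConcreteArray) rather than JAX's real classes, and not the code in this PR: a metaclass overrides __instancecheck__ so that a tracer whose abstract value looks like an array passes isinstance(x, Array), while every other object falls back to the ordinary subclass check.

```python
class AbstractArray:
    """Hypothetical stand-in for an abstract value: shape and dtype only, no data."""
    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype


class Tracer:
    """Hypothetical stand-in for a tracer that carries an abstract value."""
    def __init__(self, aval):
        self.aval = aval


class _ArrayMeta(type):
    """Metaclass whose __instancecheck__ also accepts array-valued tracers."""
    def __instancecheck__(cls, instance):
        # Duck-typed path: anything whose abstract value is array-like counts.
        if isinstance(getattr(instance, "aval", None), AbstractArray):
            return True
        # Otherwise fall back to the ordinary nominal (subclass-based) check.
        return super().__instancecheck__(instance)


class Array(metaclass=_ArrayMeta):
    """Hypothetical stand-in for jax.Array."""


class ConcreteArray(Array):
    """Hypothetical concrete array type: a real subclass of Array."""


traced = Tracer(AbstractArray((3,), "float32"))
print(isinstance(traced, Array))            # True: duck-typed via aval
print(isinstance(ConcreteArray(), Array))   # True: ordinary subclass
print(isinstance(Tracer(object()), Array))  # False: aval is not array-like
print(isinstance(1.0, Array))               # False
```

The check keys off the tracer's abstract value rather than its Python type, which is what lets a traced value pass isinstance even though it is not a subclass of Array.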
@yashk2810, I'm curious whether you see any issues with adding a mechanism like this to the Array class? Do you think this idea of unifying type checking and instance checks around the eventual jax.Array object has merit? Do you have any hesitations here?
Array is in the process of being lowered to C++. How does this affect that, given that you are adding a metaclass to Array?
Is it possible to keep Array's behavior the same as DeviceArray (DA) in this case? DA is also implemented in C++, and it doesn't have a metaclass, right?
DeviceArray does not behave this way, but we want Array to behave this way. If we lower Array to C++, we should add an equivalent metaclass in the C++ definition, so that this new test still passes. Does that sound reasonable? If not, it changes the plan we're landing on in https://github.com/google/jax/pull/11859 so it would be good to know now.
Regarding "we should add an equivalent metaclass in the C++ definition. Does that sound reasonable?":
@hawkinsp or @cky9301, is it okay if we add the metaclass in C++? I don't know whether that is possible. I am fine with it existing in Python, though.
I've found a couple of examples of overriding metaclasses / __instancecheck__ within pybind11; one is in PyTorch: https://github.com/pytorch/pytorch/blob/31ef8ddb8c4467f5b8698ef1eb9bb8bab7056855/torch/csrc/tensor/python_tensor.cpp#L149-L177
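For context on why the override has to live on a metaclass at all, whether defined in Python or in a C++ binding: isinstance(x, C) looks up __instancecheck__ on type(C), not on C itself, so defining it as an ordinary method on the class has no effect. A minimal plain-Python illustration, with hypothetical class names:

```python
class WithoutMeta:
    # Defined on the class itself: isinstance() ignores this, because the hook
    # is looked up on type(WithoutMeta), which is plain `type`.
    def __instancecheck__(self, instance):
        return True


class _AlwaysMeta(type):
    # Defined on the metaclass: isinstance(x, WithMeta) calls this hook.
    def __instancecheck__(cls, instance):
        return True


class WithMeta(metaclass=_AlwaysMeta):
    pass


print(isinstance(1, WithoutMeta))  # False: the class-level hook is never consulted
print(isinstance(1, WithMeta))     # True: the metaclass hook is used
```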
Replaced by https://github.com/google/jax/pull/12300