
jax.Array: support duck-typed isinstance checks

Open • jakevdp opened this issue 3 years ago • 4 comments

Why? As part of the type promotion discussion in JAX (#11859 & #12049), @patrick-kidger and I have been discussing which API JAX should support for both type annotations and isinstance checks. We think the best approach is to eventually make jax.Array serve both purposes.

This is a simple enhancement to jax.Array that allows it to be used in isinstance checks in traced functions, similar to how isinstance(obj, jnp.ndarray) works currently.
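A minimal, hypothetical sketch of the kind of duck-typed check described above follows. It is written in plain Python so it runs without JAX installed. `_ArrayMeta`, `Array`, `ShapedValue`, `Tracer`, and the `aval` attribute are illustrative stand-ins, not JAX's actual classes. The point is that overriding `__instancecheck__` on the metaclass lets `isinstance(x, Array)` return True for tracer-like objects that are not nominal subclasses.

```python
class _ArrayMeta(type):
    """Hypothetical metaclass giving Array a duck-typed isinstance check."""

    def __instancecheck__(cls, instance):
        # Assumption: tracer-like objects expose an abstract value via `aval`
        # that carries `shape` and `dtype`; treat those as arrays.
        aval = getattr(instance, "aval", None)
        if aval is not None and hasattr(aval, "shape") and hasattr(aval, "dtype"):
            return True
        # Otherwise fall back to ordinary (nominal) isinstance semantics.
        return super().__instancecheck__(instance)


class Array(metaclass=_ArrayMeta):
    """Hypothetical stand-in for a concrete jax.Array."""

    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype


class ShapedValue:
    """Hypothetical abstract value: shape and dtype, but no data."""

    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype


class Tracer:
    """Hypothetical tracer: not an Array subclass, but wraps an abstract value."""

    def __init__(self, aval):
        self.aval = aval


x = Array((2, 3), "float32")
t = Tracer(ShapedValue((2, 3), "float32"))
print(isinstance(x, Array))    # True: ordinary instance
print(isinstance(t, Array))    # True: duck-typed via the metaclass
print(isinstance(3.0, Array))  # False: no aval and not a subclass
```

With this design, `isinstance` stays nominal for ordinary objects and widens only for values that look like traced arrays. `issubclass(Tracer, Array)` is still False, because only `__instancecheck__` is overridden.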

@yashk2810, I'm curious whether you see any issues with adding a mechanism like this to the Array class? Do you think this idea of unifying type checking and instance checks around the eventual jax.Array object has merit? Do you have any hesitations here?

jakevdp, Sep 07 '22 16:09

Array is in the process of being lowered to C++. How does this affect that, given that you are adding a metaclass to Array?

Is it possible to keep Array's behavior similar to DeviceArray (DA) here? DA is also implemented in C++, and it doesn't have a metaclass, right?

yashk2810, Sep 07 '22 16:09

DeviceArray does not behave this way, but we want Array to. If we lower Array to C++, we should add an equivalent metaclass in the C++ definition so that this new test still passes. Does that sound reasonable? If not, that changes the plan we're landing on in https://github.com/google/jax/pull/11859, so it would be good to know now.

jakevdp, Sep 07 '22 17:09

> we should add an equivalent metaclass in the C++ definition. Does that sound reasonable?

@hawkinsp or @cky9301, is it okay if we add the metaclass in C++? I don't know whether that is possible. I'm fine with it existing in Python, though.

yashk2810, Sep 07 '22 17:09

I've found a couple of examples of overriding metaclasses / __instancecheck__ with pybind11; one is in PyTorch: https://github.com/pytorch/pytorch/blob/31ef8ddb8c4467f5b8698ef1eb9bb8bab7056855/torch/csrc/tensor/python_tensor.cpp#L149-L177

jakevdp, Sep 07 '22 17:09

Replaced by https://github.com/google/jax/pull/12300

jakevdp, Oct 07 '22 14:10