Add JAX backend
I'm wondering if there's any interest in supporting JAX? I imagine there could be some additional speedup through its just-in-time compilation. If so, I'd love to contribute to this!
Hey @TuanNguyen27, in general we would like to support as many backends as possible. We already have a TVM version, which we will open source soon, and I think JAX would be a great addition. My only concerns are:
- which performance improvements will we get by using JAX?
- who will be supporting the JAX backend as we will be adding new operators?
Do you think you could run a quick and dirty experiment to explore the performance improvements from using JAX? Could you also give us an idea of whether you would be able to help us support JAX going forward? I am asking because we don't have JAX experience.
Hi @interesaaat, I do have some experience with JAX (not a power user, but quite comfortable), and I'm interested in experimenting with potential JAX performance improvements. If this experiment has promising results, I am willing to commit to point 2) of your concerns.
Perhaps let's start with a proof of concept before we decide whether this is worth pursuing further. Could you give me some pointers on how I can benchmark this, e.g. what's the process for adding a new backend, and which basic operations should perform reasonably fast? Thanks for the discussion :)
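On the timing side, this is roughly the harness I would start from. It is only a sketch: `time_fn` and `dense_layer` are placeholder names I made up, and the workload is a stand-in, not one of the project's operators. It runs one warm-up call so that JIT compilation is not counted, and it calls `block_until_ready()` because JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp


def time_fn(fn, *args, n_runs=10):
    """Mean wall-clock seconds per call, excluding the first (warm-up) call."""
    # Warm-up: for a jitted function this triggers tracing and compilation.
    fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(n_runs):
        # Wait for the result so we measure compute, not just dispatch.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / n_runs


def dense_layer(x, w):
    # Hypothetical stand-in workload: a matmul followed by a ReLU.
    return jnp.maximum(x @ w, 0.0)


kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (256, 64))
w = jax.random.normal(kw, (64, 32))

eager_s = time_fn(dense_layer, x, w)
jit_s = time_fn(jax.jit(dense_layer), x, w)
print(f"eager: {eager_s:.6f} s/call, jit: {jit_s:.6f} s/call")
```

The same `time_fn` could wrap whatever the JAX ports end up exposing, as long as they return a single array.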
@interesaaat pinging you again in case you could give me some pointers to get started.
Hey @TuanNguyen27, sorry, I somehow missed your reply. I think it would be good to test 2 tree implementations: GEMM and TreTrav. The implementations are here (the decision tree and GBDT implementations actually do the same thing; I will eventually merge them).
If you don't want to add a backend directly but just want to experiment, you can try to re-implement those 2 in JAX and run some experiments. I would start with GEMM since it's a bit simpler/shorter.
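To give a rough idea of what a GEMM-style port could look like, here is a minimal JAX sketch on a hand-built toy tree. It is not the project's implementation: the tree, the `x[f] < threshold` split convention, the function name `gemm_tree_predict`, and the matrix labels `A` to `E` are all assumptions made for illustration. The idea is to evaluate every split at once with one matrix product, then use a second product to pick the single leaf whose path conditions all hold.

```python
import jax
import jax.numpy as jnp

# Toy tree (hypothetical, hand-built):
#   node 0: x[0] < 0.5 ? go to node 1 : leaf 2
#   node 1: x[1] < 0.3 ? leaf 0       : leaf 1
# A[f, i] = 1 if internal node i tests feature f.
A = jnp.array([[1.0, 0.0],
               [0.0, 1.0]])
# B[i] = threshold of internal node i.
B = jnp.array([0.5, 0.3])
# C[i, l] = +1 if leaf l is in the left subtree of node i,
#           -1 if it is in the right subtree, 0 if node i is not on its path.
C = jnp.array([[1.0, 1.0, -1.0],
               [1.0, -1.0, 0.0]])
# D[l] = number of ancestors of leaf l that are reached by going left.
D = jnp.array([2.0, 1.0, 0.0])
# E[l, :] = output value stored at leaf l.
E = jnp.array([[10.0], [20.0], [30.0]])


@jax.jit
def gemm_tree_predict(X):
    # 1 where the split condition holds (go left), 0 otherwise.
    T = (X @ A < B).astype(jnp.float32)
    # A leaf is selected when its left/right path constraints all match.
    T = ((T @ C) == D).astype(jnp.float32)
    # One-hot leaf selection times leaf values gives the prediction.
    return T @ E


X = jnp.array([[0.2, 0.1],
               [0.2, 0.9],
               [0.9, 0.1],
               [0.9, 0.9]])
preds = gemm_tree_predict(X)
print(preds.ravel())  # leaf values 10, 20, 30, 30 for the four rows
```

In a real port the matrices would be built from the trained model rather than written by hand, and an ensemble would add a batch dimension over trees.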
If you want to add the backend, we need to put some infrastructure in place, because the code currently assumes that PyTorch is the only backend (we have a PR for ONNX, but it's a bit different because that goes from ONNX to ONNX, not from Sklearn to ONNX).
Thanks for the pointers! I agree with the approach; let me get back to you once I have a working JAX port of GEMM and TreTrav.