
Add JAX backend

Open TuanNguyen27 opened this issue 5 years ago • 5 comments

I'm wondering if there's any interest in supporting JAX? I imagine there could be some additional speedup from its just-in-time (JIT) compilation. If so, I'd love to contribute to this!

TuanNguyen27 • Jun 16 '20 04:06

Hey @TuanNguyen27, in general we would like to support as many backends as possible. We already have a TVM version, which we will open-source soon, and I think JAX would be a great addition. My only concerns are:

  1. What performance improvements will we get by using JAX?
  2. Who will support the JAX backend as we add new operators?

Do you think you could run a quick-and-dirty experiment to explore the performance improvements from JAX? Could you also give us an idea of whether you could help us support the JAX backend going forward? I am asking because we don't have JAX experience.

interesaaat • Jun 16 '20 17:06

Hi @interesaaat, I do have some experience with JAX (not a power user, but quite comfortable), and I'm interested in experimenting with potential JAX performance improvements. If the experiment shows promising results, I am willing to commit to point 2 of your concerns.

Let's start with a proof of concept before we decide whether this is worth pursuing further. Could you give me some pointers on how to benchmark this, e.g., what the process for adding a new backend is, and some basic operations that should perform reasonably fast? Thanks for the discussion :)
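One way to get first numbers for point 1 is a small timing harness. The sketch below is an assumption, not existing Hummingbird tooling: `benchmark` and the stand-in workload `score` are hypothetical names. It times the same function eagerly and under `jax.jit`, uses a warm-up call so compilation time is excluded, and calls `block_until_ready()` because JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp


def benchmark(fn, *args, repeats=20):
    """Mean wall-clock seconds per call of fn(*args), after one warm-up call."""
    # Warm-up: for a jax.jit-wrapped fn this triggers tracing and compilation,
    # so compile time is excluded from the measurement below.
    fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches asynchronously; block so we time the actual compute.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


def score(X, W):
    # Hypothetical stand-in workload: a matmul followed by elementwise ops.
    return jnp.tanh(X @ W).sum(axis=1)


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (10_000, 32))
W = jax.random.normal(key_w, (32, 64))

eager_s = benchmark(score, X, W)
jit_s = benchmark(jax.jit(score), X, W)
print(f"eager: {eager_s * 1e3:.3f} ms/call, jit: {jit_s * 1e3:.3f} ms/call")
```

Swapping `score` for a JAX port of a tree kernel, and timing the existing PyTorch implementation on the same inputs, would give the comparison asked about in point 1. Absolute numbers depend heavily on hardware and batch size.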

TuanNguyen27 • Jun 16 '20 18:06

@interesaaat, pinging you again in case you can give me some pointers to get started.

TuanNguyen27 • Jun 18 '20 14:06

Hey @TuanNguyen27, sorry, I somehow missed your reply. I think it would be good to test two tree implementations: GEMM and TreTrav. The implementations are here (the decision tree and GBDT implementations actually do the same thing; I will eventually merge them).
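For reference, a minimal JAX sketch of the GEMM idea for a single tree, under my own assumptions about the layout (the matrix names `A` to `E` and the function `gemm_tree_predict` are hypothetical, not taken from the Hummingbird code): one matrix product evaluates every internal-node comparison at once, and a second product identifies the leaf each sample reaches.

```python
import jax
import jax.numpy as jnp


@jax.jit
def gemm_tree_predict(X, A, B, C, D, E):
    """Evaluate one decision tree with dense matrix products (GEMM-style sketch).

    X: (n_samples, n_features)   input batch
    A: (n_features, n_internal)  A[f, i] = 1 if internal node i tests feature f
    B: (n_internal,)             thresholds of the internal nodes
    C: (n_internal, n_leaves)    +1 / -1 if the leaf is in node i's left / right subtree
    D: (n_leaves,)               number of left edges on the root-to-leaf path
    E: (n_leaves, n_outputs)     value stored at each leaf
    """
    # Evaluate every node condition at once: 1.0 means "go left" (x[f] < threshold).
    T = (X @ A < B).astype(X.dtype)
    # Only the reached leaf has all of its path conditions satisfied,
    # which makes T @ C hit exactly its left-edge count D.
    reached = (T @ C == D).astype(X.dtype)
    # The one-hot leaf indicator selects that leaf's value.
    return reached @ E


# Toy depth-2 tree: root tests x0 < 0.5; its left child tests x1 < 0.3,
# its right child tests x1 < 0.7. Leaves (left to right) hold 10, 20, 30, 40.
A = jnp.array([[1.0, 0.0, 0.0],
               [0.0, 1.0, 1.0]])
B = jnp.array([0.5, 0.3, 0.7])
C = jnp.array([[1.0, 1.0, -1.0, -1.0],
               [1.0, -1.0, 0.0, 0.0],
               [0.0, 0.0, 1.0, -1.0]])
D = jnp.array([2.0, 1.0, 1.0, 0.0])
E = jnp.array([[10.0], [20.0], [30.0], [40.0]])

X = jnp.array([[0.2, 0.1], [0.2, 0.9], [0.8, 0.5], [0.8, 0.9]])
preds = gemm_tree_predict(X, A, B, C, D, E)  # shape (4, 1)
```

Each of the four samples lands in a different leaf, so `preds` should be `[[10.], [20.], [30.], [40.]]`. For a forest or GBDT, the per-tree matrices could be padded and stacked along a tree axis; that extension is not shown here.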

If you don't want to add a backend directly but just want to experiment, you can re-implement those two in JAX and run some experiments. I would start with GEMM since it's a bit simpler and shorter.
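A comparable sketch of tree traversal in JAX, again with hypothetical names (`tree_trav_predict` and the per-node arrays `feature`, `threshold`, `left`, `right`, `value`) rather than Hummingbird's own: all samples advance one level per iteration through gathers, and `jax.lax.fori_loop` keeps the loop inside the compiled function. The toy tree is depth 2 (root tests x0 < 0.5, children test x1 < 0.3 and x1 < 0.7, leaves hold 10, 20, 30, 40), with node ids 0 to 2 for internal nodes and 3 to 6 for leaves.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="max_depth")
def tree_trav_predict(X, feature, threshold, left, right, value, max_depth):
    """Walk every sample down one decision tree, one level per iteration.

    Nodes are stored as parallel arrays indexed by node id. Leaves point to
    themselves (left[leaf] == right[leaf] == leaf), so samples that reach a
    leaf early simply stay there for the remaining iterations.
    """
    rows = jnp.arange(X.shape[0])
    node = jnp.zeros(X.shape[0], dtype=jnp.int32)  # every sample starts at the root

    def step(_, node):
        # Gather each sample's tested feature value at its current node.
        x = X[rows, feature[node]]
        return jnp.where(x < threshold[node], left[node], right[node])

    # max_depth iterations are enough to reach a leaf from the root.
    node = jax.lax.fori_loop(0, max_depth, step, node)
    return value[node]


feature = jnp.array([0, 1, 1, 0, 0, 0, 0], dtype=jnp.int32)
threshold = jnp.array([0.5, 0.3, 0.7, 0.0, 0.0, 0.0, 0.0])
left = jnp.array([1, 3, 5, 3, 4, 5, 6], dtype=jnp.int32)
right = jnp.array([2, 4, 6, 3, 4, 5, 6], dtype=jnp.int32)
value = jnp.array([0.0, 0.0, 0.0, 10.0, 20.0, 30.0, 40.0])

X = jnp.array([[0.2, 0.1], [0.2, 0.9], [0.8, 0.5], [0.8, 0.9]])
preds = tree_trav_predict(X, feature, threshold, left, right, value, max_depth=2)
```

Design-wise, the GEMM approach evaluates every node for every sample but maps onto dense matrix products, while traversal only touches the nodes on each sample's path at the cost of data-dependent gathers; which one is faster under JAX is what the experiment would show.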

If you want to add the backend, we need to put some infrastructure in place, because the code currently assumes that PyTorch is the only backend (we have a PR for ONNX, but it's a bit different because it goes from ONNX to ONNX, not from Sklearn to ONNX).
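To make the infrastructure point concrete, here is a minimal sketch of one way conversion could dispatch on a backend name rather than assume a single backend. Everything in it (`register_backend`, `convert`, `_BUILDERS`, `build_jax_gemm`, and the `params` dict with keys `A` to `E`) is hypothetical and is not Hummingbird's actual API; the JAX builder uses a GEMM-style formulation for a single tree.

```python
from typing import Callable, Dict

import jax
import jax.numpy as jnp

# Hypothetical registry mapping a backend name to a "builder" that turns
# extracted tree parameters into a callable predictor for that backend.
_BUILDERS: Dict[str, Callable[[dict], Callable]] = {}


def register_backend(name: str):
    """Decorator that registers a builder under `name`."""
    def decorator(builder):
        _BUILDERS[name] = builder
        return builder
    return decorator


def convert(params: dict, backend: str) -> Callable:
    """Dispatch on the backend name instead of assuming a single backend."""
    if backend not in _BUILDERS:
        raise ValueError(f"unsupported backend {backend!r}; known: {sorted(_BUILDERS)}")
    return _BUILDERS[backend](params)


@register_backend("jax")
def build_jax_gemm(params: dict) -> Callable:
    # params holds GEMM-style matrices: A (feature selection), B (thresholds),
    # C (leaf-path signs), D (left-edge counts per leaf), E (leaf values).
    A, B, C, D, E = (params[k] for k in "ABCDE")

    @jax.jit
    def predict(X):
        T = (X @ A < B).astype(X.dtype)  # 1.0 where a node sends the sample left
        return (T @ C == D).astype(X.dtype) @ E  # one-hot reached leaf -> value

    return predict


# A PyTorch builder would be registered the same way, e.g. @register_backend("torch").

# Toy decision stump: x0 < 0.5 -> 1.0, otherwise -> 2.0.
stump = {
    "A": jnp.array([[1.0], [0.0]]),
    "B": jnp.array([0.5]),
    "C": jnp.array([[1.0, -1.0]]),
    "D": jnp.array([1.0, 0.0]),
    "E": jnp.array([[1.0], [2.0]]),
}
predict = convert(stump, backend="jax")
out = predict(jnp.array([[0.2, 0.0], [0.9, 0.0]]))
```

Keeping the model-parsing step backend-agnostic and only the final "build a predictor" step per backend is one possible split; the actual design would depend on how the existing PyTorch converters are structured.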

interesaaat • Jun 18 '20 16:06

Thanks for the pointer! I agree with the approach; I'll get back to you once I have a working JAX port of GEMM and TreTrav.

TuanNguyen27 • Jun 18 '20 17:06