neuropod
Add support for JAX/XLA/Trax?
A few months ago, I noticed that JAX (https://github.com/google/jax) and Trax (https://github.com/google/trax) have been getting more popular.
JAX functions that are compiled with jit (https://github.com/google/jax#compilation-with-jit) can be turned into an XLA HLO proto (see https://github.com/google/jax/issues/1871), which can then be run from C++.
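To make the jit-to-XLA path concrete, here is a minimal sketch using JAX's ahead-of-time API (jax.jit(...).lower(...) and .compile()). The function predict and its input shapes are hypothetical stand-ins for a model. The exact call for exporting a serialized HLO proto has changed across JAX releases, so this sketch stops at the textual IR that lowering produces (StableHLO in recent releases) and then compiles and runs it in Python as a sanity check. A C++ runner for that IR is out of scope here.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy function standing in for a model's forward pass.
def predict(w, x):
    return jnp.tanh(x @ w)

# Example inputs; their shapes and dtypes fix the signature of the lowered program.
w = jnp.ones((3, 2), dtype=jnp.float32)
x = jnp.ones((4, 3), dtype=jnp.float32)

# Trace and lower the jit-compiled function without executing it.
lowered = jax.jit(predict).lower(w, x)

# Textual IR of the lowered program (StableHLO in recent JAX releases);
# an artifact like this is what a JAX packager would need to persist.
ir_text = lowered.as_text()
print(ir_text[:200])

# Compile the lowered program with XLA and run it to check the result.
compiled = lowered.compile()
out = compiled(w, x)
print(out.shape)  # (4, 2)
```

Because lowering is driven by concrete shapes and dtypes, a packager would likely need to record the input signature alongside the compiled artifact.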
Trax can use TF, NumPy, or JAX under the hood, so I don't think we need to do much additional work to add support for it.
Concretely, we'd need to add a backend for XLA and packagers for JAX and Trax.
Note: Flax (https://github.com/google/flax) is another DL library built on top of JAX.
Is my understanding correct that this is a GPU/TPU-only optimization?