
Add support for JAX/XLA/Trax?

Open · VivekPanyam opened this issue 4 years ago · 2 comments

A few months ago, I noticed that JAX (https://github.com/google/jax) and Trax (https://github.com/google/trax) have been getting more popular.

JAX functions that are compiled with `jit` (https://github.com/google/jax#compilation-with-jit) can be exported as an XLA HLO proto (see https://github.com/google/jax/issues/1871), which can then be run from C++.
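As a rough illustration of the export step, the sketch below jit-compiles a toy function and lowers it for fixed input shapes. It assumes a recent JAX release with the `jax.jit(...).lower(...)` API. Older approaches, such as the `jax.xla_computation` helper discussed around the linked issue, may be deprecated or removed in newer versions. The function `predict` and its shapes are hypothetical. `as_text()` returns the lowered module as (Stable)HLO text. Producing the serialized HLO proto that a C++ XLA backend would load is version-dependent and is not shown here.

```python
import jax
import jax.numpy as jnp


def predict(w, x):
    # Hypothetical toy "model": one dense layer followed by tanh.
    return jnp.tanh(x @ w)


# Example inputs; lowering specializes the program to these shapes and dtypes.
w = jnp.ones((3, 2), dtype=jnp.float32)
x = jnp.ones((4, 3), dtype=jnp.float32)

jitted = jax.jit(predict)

# Trace and lower to XLA's input IR without executing the function.
lowered = jitted.lower(w, x)
ir_text = lowered.as_text()

# Running the jitted function compiles (via XLA) and executes it.
y = jitted(w, x)

print(ir_text)
print(y.shape)
```

Because lowering is shape-specialized, an XLA backend would likely need to store the input shapes and dtypes alongside the exported program.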

Trax can use TF, NumPy, or JAX under the hood, so I don't think we need to do much additional work to support it.

Concretely, we'd need to add an XLA backend, plus packagers for JAX and Trax.

VivekPanyam · Apr 08 '20 22:04

Note: Flax (https://github.com/google/flax) is another DL library built on top of JAX.

VivekPanyam · Apr 08 '20 22:04

Do I understand correctly that this is only a GPU/TPU optimization?

vkuzmin-uber · Apr 09 '20 06:04