
[jax2tf] Implement jax2tf(pjit) for experimental_native_lowering

Open gnecula opened this issue 2 years ago • 0 comments

This implementation covers the case jax2tf.convert(pjit(f_jax)), that is, the case where the pjit appears at the top level of the function to be lowered.
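A minimal sketch of the call shape described above, under stated assumptions. `f_jax` and `convert_for_tf` are hypothetical names chosen for illustration. The sketch assumes a JAX release in which `pjit` can be called without explicit sharding arguments, and it runs on whatever devices are available. The native-lowering flag named in the issue title is deliberately not passed, because its keyword spelling depends on the installed JAX version.

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental.pjit import pjit


def f_jax(x):
    # Hypothetical example function; not taken from the issue.
    return jnp.sin(x) * 2.0


# pjit wraps f_jax at the top level, which is the case the issue targets.
# Assumption: this JAX release lets pjit be called without sharding arguments.
f_pjit = pjit(f_jax)

x = np.arange(4, dtype=np.float32)
y_jax = f_pjit(x)


def convert_for_tf(f):
    # Hypothetical helper. jax2tf is imported lazily so the JAX-only part
    # above still runs when TensorFlow is not installed.
    from jax.experimental import jax2tf

    # Native lowering is not requested explicitly here; whether it is used
    # depends on the defaults of the installed JAX version.
    return jax2tf.convert(f)
```

With TensorFlow installed, `f_tf = convert_for_tf(f_pjit)` would give the `jax2tf.convert(pjit(f_jax))` form from the issue, and `f_tf` could then be called on a TensorFlow tensor built from `x`.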

gnecula • Sep 11 '22 18:09