[jax2tf] Implement jax2tf(pjit) for experimental_native_lowering
This implementation handles the case jax2tf.convert(pjit(f_jax)), that is, when
the pjit appears at the top level of the function being lowered.