jax
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
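A minimal sketch of what "composable" means here. It chains `jax.grad`, `jax.vmap`, and `jax.jit` on one function to get per-example gradients. The function `loss` and the array shapes are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Hypothetical scalar loss of parameters w on a single input x.
    return jnp.sum(jnp.tanh(w * x) ** 2)

# Differentiate with respect to the first argument (w).
grad_loss = jax.grad(loss)

# Vectorize over the leading axis of x (w is shared), then JIT-compile the result.
per_example_grads = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0)))

w = jnp.array([0.5, -1.0, 2.0])
xs = jnp.ones((4, 3))
g = per_example_grads(w, xs)
print(g.shape)  # (4, 3): one gradient of shape w.shape per example
```

Each transformation returns an ordinary Python function, so they can be nested in any order.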
Before, on CPU: After, on CPU: Following up on #11866. I don't know if these small compile time differences matter, but hey, roofshots! And the jaxpr pretty-printing win is enough...
**Before:** **After:** No more `lt`, `add`, `select`, or `convert_element_type`! This might actually improve compilation time a bit; on CPU, **the tiny benchmark added went from ~15.5ms to 13ms on my...
Make all pmap tests pass with Array!
Fix a typo in jax2tf README.
[mhlo] Unify different versions of round op (viz., mhlo::round_nearest_afz & mhlo::round_nearest_even) into one. Currently in MHLO, we have two versions of round ops with different rounding modes. This CL unifies...
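The two MHLO ops differ only in how they break ties at .5. The sketch below shows the two behaviors through JAX's user-facing `jax.lax.round` and its `rounding_method` argument. It illustrates the rounding modes only. It does not show how the unified MHLO op described in this change is defined.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([0.5, 1.5, 2.5, -2.5])

# Ties rounded away from zero (the round_nearest_afz behavior).
afz = lax.round(x, lax.RoundingMethod.AWAY_FROM_ZERO)   # [1., 2., 3., -3.]

# Ties rounded to the nearest even integer (the round_nearest_even behavior).
even = lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)  # [0., 2., 2., -2.]
```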
Move tensorflow/core/platform/{default, google, windows} to tensorflow/tsl/platform/...
**Work In Progress** This JEP introduces a roadmap for type annotations in JAX, addressing a number of goals, non-goals, and design decisions that need to be made. Rendered JEP here:...