Peter Hawkins
Currently JAX cannot reuse the input buffers of a computation for its outputs. This means that for a typical neural network training step, we require enough space to store two copies of...
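For reference, current JAX exposes input-buffer donation through `jax.jit`'s `donate_argnums` parameter. A minimal sketch of what reusing the parameter buffer looks like (the SGD update and shapes are illustrative, not from the original issue):

```python
import functools
import jax
import jax.numpy as jnp

# Donating argument 0 tells XLA it may reuse the `params` buffer for the
# output, so the training step does not need two live copies of the weights.
@functools.partial(jax.jit, donate_argnums=0)
def sgd_step(params, grads):
    return params - 0.01 * grads

params = jnp.ones((1024, 1024))
grads = jnp.ones((1024, 1024))
params = sgd_step(params, grads)  # the donated input buffer may be reused in place
```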
The JAX NaN checker computes `onp.any(onp.isnan(buf.to_py()))`, i.e., it transfers the entire array to the host and then tests for the presence of NaNs there. It would be much more bandwidth efficient to...
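One bandwidth-friendly alternative, shown here only as a sketch and not as the checker's actual implementation, is to run the reduction on the device and transfer a single scalar:

```python
import jax
import jax.numpy as jnp

@jax.jit
def has_nan(x):
    # The reduction runs on the device; only one boolean crosses the
    # device-to-host boundary when the result is inspected.
    return jnp.any(jnp.isnan(x))

x = jnp.ones((4096, 4096))
print(bool(has_nan(x)))  # transfers one scalar, not the whole array
```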
Currently `jaxlib` consists of essentially three things:

* XLA, and its CPU and GPU backends.
* A little C++ runtime around XLA.
* Python bindings around both.

We build these...
```
import iree.compiler

CODE = """
#loc0 = loc(unknown)
module @jit_prim_fun.12 {
  func.func public @main(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)) -> tensor {
    %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b,...
```
Repro:

```
from iree.compiler import compile_str

CODE = """
module {
  func @main(%arg0: tensor, %arg1: tensor) -> tensor {
    %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, f]x[0, i, o]->[b,...
```
The custom call partitioner callback was not present in version 94 but is present in version 95.
The `tpu_initializer_helper` should not be linked in if TPU support is disabled:

```
$ python build/build.py --configure_only
...
TPU enabled: no
Remote TPU enabled: no
...
$ bazel cquery 'somepath(//build:build_wheel,@org_tensorflow//tensorflow/core/tpu:tpu_initializer_helper)'...
```
On GPU we dispatch XLA computations synchronously inline on the main Python thread. Frequently we find that GPU computation dispatch is slow and can take time comparable to the execution...
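A rough way to separate the host-side dispatch cost from device execution time (a sketch; the function, shapes, and timing scheme are illustrative, not from the issue):

```python
import time
import jax
import jax.numpy as jnp

f = jax.jit(lambda x: jnp.dot(x, x))
x = jnp.ones((2048, 2048))
f(x).block_until_ready()  # compile and warm up

t0 = time.perf_counter()
y = f(x)                  # host-side dispatch only; the result is still in flight
t1 = time.perf_counter()
y.block_until_ready()     # wait for the device to finish executing
t2 = time.perf_counter()

print(f"dispatch: {(t1 - t0) * 1e6:.1f} us, execution wait: {(t2 - t1) * 1e6:.1f} us")
```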
### What happened?

The following JAX program, run with `JAX_PLATFORMS=iree`, benchmarks reductions over various axes of an array vs numpy. My particular build of numpy is single-threaded.

```
import jax.numpy...
```
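The program itself is truncated above; a minimal sketch of this kind of reduction benchmark (the array size, axes, and timing loop are illustrative, not the original code):

```python
import timeit
import numpy as np
import jax.numpy as jnp

x_np = np.ones((256, 256, 256), dtype=np.float32)
x_jax = jnp.asarray(x_np)

for axis in range(x_np.ndim):
    t_np = timeit.timeit(lambda: np.sum(x_np, axis=axis), number=10)
    # block_until_ready() ensures we time the reduction itself,
    # not just the asynchronous dispatch.
    t_jax = timeit.timeit(
        lambda: jnp.sum(x_jax, axis=axis).block_until_ready(), number=10)
    print(f"axis={axis}: numpy {t_np:.4f}s, jax {t_jax:.4f}s")
```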
Currently bazel builds of XLA from the OpenXLA repository require disabling visibility checks via the `--nocheck_visibility` argument to Bazel. It would be great if we could fix this!