Peter Hawkins

26 issues by Peter Hawkins

Currently, JAX cannot reuse a computation's input buffers for its outputs. This means that for a typical neural network training step, we require enough space to store 2 copies of...

enhancement
P2 (eventual)
NVIDIA GPU
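
For context, current JAX releases let a caller opt into this kind of buffer reuse through the `donate_argnums` argument of `jax.jit`. Below is a minimal sketch, assuming an SGD-style update; `sgd_step` and the learning rate are illustrative, not taken from the issue. Whether XLA actually aliases the donated buffer depends on the backend, and when it cannot, JAX warns and allocates a fresh output.

```python
import jax
import jax.numpy as jnp
import numpy as np

def sgd_step(params, grads, lr=0.1):
    # Output has the same shape and dtype as `params`, so XLA may write it
    # into the donated input buffer instead of allocating a second copy.
    return params - lr * grads

# donate_argnums=0 donates `params`: the caller promises not to reuse it.
sgd_step_donated = jax.jit(sgd_step, donate_argnums=0)

params = jnp.ones((4,), dtype=jnp.float32)
grads = jnp.full((4,), 2.0, dtype=jnp.float32)
new_params = sgd_step_donated(params, grads)
# Do not read `params` after this call; a backend that honors donation
# has invalidated its buffer.
print(np.asarray(new_params))
```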

The JAX NaN checker computes `onp.any(onp.isnan(buf.to_py()))`, i.e., it transfers the whole array to the host and then tests it for NaNs. It would be much more bandwidth efficient to...

enhancement
P2 (eventual)
NVIDIA GPU
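
The snippet's proposal is truncated, so the following is only one plausible direction, not necessarily the author's: run the `isnan`/`any` reduction on the device so that only a scalar boolean crosses to the host. `has_nan_host` and `has_nan_device` are hypothetical helper names, not JAX's actual NaN-checker code, and modern `numpy` stands in for the older `onp` alias.

```python
import jax
import jax.numpy as jnp
import numpy as np

def has_nan_host(x) -> bool:
    # Baseline described in the issue: copy the full array to the host, then scan it.
    return bool(np.any(np.isnan(np.asarray(x))))

@jax.jit
def _any_nan(x):
    # The reduction runs on the device and produces a single boolean.
    return jnp.any(jnp.isnan(x))

def has_nan_device(x) -> bool:
    # Only the one-element result is transferred to the host.
    return bool(_any_nan(x))
```

For an N-element float32 array, the host-side version moves 4N bytes per check, while the device-side version moves one scalar.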

Currently `jaxlib` consists of essentially three things:
* XLA and its CPU and GPU backends,
* a small C++ runtime around XLA, and
* Python bindings around both.
We build these...

enhancement
build
P1 (soon)
NVIDIA GPU

``` import iree.compiler CODE = """ #loc0 = loc(unknown) module @jit_prim_fun.12 { func.func public @main(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)) -> tensor { %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b,...

bug 🐞
help wanted
integrations/mhlo

Repro: ``` from iree.compiler import compile_str CODE = """ module { func @main(%arg0: tensor, %arg1: tensor) -> tensor { %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, f]x[0, i, o]->[b,...

bug 🐞
help wanted
codegen
codegen/llvm

The custom call partitioner callback was not present in version 94 but is present in version 95.

pull ready
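
Code that depends on the partitioner callback would typically gate on the version number. Here is a minimal sketch, assuming an integer version is available at runtime; the snippet does not say which library reports "version 94/95", so the constant and function names below are hypothetical.

```python
# Hypothetical: the first version in which the callback exists, per the issue.
CUSTOM_CALL_PARTITIONER_MIN_VERSION = 95

def has_custom_call_partitioner(version: int) -> bool:
    """Return True if the custom call partitioner callback should be present."""
    return version >= CUSTOM_CALL_PARTITIONER_MIN_VERSION
```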

The `tpu_initializer_helper` should not be linked in if TPU support is disabled: ``` $ python build/build.py --configure_only ... TPU enabled: no Remote TPU enabled: no ... $ bazel cquery 'somepath(//build:build_wheel,@org_tensorflow//tensorflow/core/tpu:tpu_initializer_helper)'...
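
The `bazel cquery 'somepath(...)'` check from the issue could be scripted, for example as a CI guard. This is a sketch only, assuming `bazel` is on `PATH`, the build has already been configured (e.g. via `build/build.py --configure_only`), and an empty query result exits with status 0. `somepath_query` and `depends_on` are hypothetical helpers.

```python
import subprocess

def somepath_query(src: str, dst: str) -> str:
    # Same form as the query in the issue: somepath(<from>,<to>).
    return f"somepath({src},{dst})"

def depends_on(src: str, dst: str) -> bool:
    """True if `bazel cquery` reports a dependency path from src to dst."""
    proc = subprocess.run(
        ["bazel", "cquery", somepath_query(src, dst)],
        capture_output=True, text=True, check=True,
    )
    # cquery prints one configured target per line; no output means no path.
    return bool(proc.stdout.strip())

WHEEL = "//build:build_wheel"
TPU_HELPER = "@org_tensorflow//tensorflow/core/tpu:tpu_initializer_helper"
```

With TPU support disabled, the desired outcome would be `depends_on(WHEEL, TPU_HELPER)` returning `False`.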

On GPU, we dispatch XLA computations synchronously, inline on the main Python thread. We frequently find that GPU computation dispatch is slow and can take time comparable to the execution...

enhancement
P1 (soon)
NVIDIA GPU
GPU
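
The general idea of moving dispatch off the main thread can be illustrated with the standard library. This is a generic sketch, not how JAX implements dispatch. `slow_launch` is a hypothetical stand-in for the host-side cost of launching a computation, and a single worker thread preserves submission order.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor

# One worker thread runs launches in the order they were submitted.
executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="dispatch")

def slow_launch(x: int) -> int:
    # Hypothetical stand-in for an expensive launch.
    time.sleep(0.05)
    return x * 2

def dispatch_async(launch, *args) -> Future:
    # Enqueue the launch and return a Future immediately.
    return executor.submit(launch, *args)

start = time.perf_counter()
futures = [dispatch_async(slow_launch, i) for i in range(4)]
enqueue_time = time.perf_counter() - start  # main thread did not wait on launches
results = [f.result() for f in futures]     # block only when values are needed
executor.shutdown(wait=True)
```

The main thread pays only the enqueue cost, while the launches overlap with whatever Python work follows.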

### What happened? The following JAX program, run with `JAX_PLATFORMS=iree`, benchmarks reductions over various axes of an array vs numpy. My particular build of numpy is single-threaded. ``` import jax.numpy...

bug 🐞
awaiting-triage
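
The issue's benchmark program is truncated. The following is a sketch of the general method it describes: time a sum over each axis in JAX and in NumPy. Compilation happens outside the timed region, and `block_until_ready()` ensures asynchronous dispatch does not hide execution time. The shape, repetition count, and `bench_reductions` are assumptions. The sketch runs on whatever platform JAX selects by default, so reproducing the IREE case would additionally need `JAX_PLATFORMS=iree` with an IREE-enabled build.

```python
import timeit

import jax
import jax.numpy as jnp
import numpy as np

# `axis` is static, so each axis gets its own compiled reduction.
jax_sum = jax.jit(jnp.sum, static_argnames="axis")

def bench_reductions(shape=(256, 256, 64), number=10):
    """Return {axis: (jax_seconds_per_call, numpy_seconds_per_call)}."""
    x_np = np.random.default_rng(0).standard_normal(shape, dtype=np.float32)
    x_jax = jnp.asarray(x_np)
    results = {}
    for axis in range(len(shape)):
        # Compile and warm up before timing.
        jax_sum(x_jax, axis=axis).block_until_ready()
        t_jax = timeit.timeit(
            lambda: jax_sum(x_jax, axis=axis).block_until_ready(), number=number)
        t_np = timeit.timeit(lambda: np.sum(x_np, axis=axis), number=number)
        results[axis] = (t_jax / number, t_np / number)
    return results

for axis, (t_jax, t_np) in bench_reductions().items():
    print(f"axis={axis}: jax {t_jax * 1e3:.3f} ms, numpy {t_np * 1e3:.3f} ms")
```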

Currently, Bazel builds of XLA from the OpenXLA repository require disabling visibility checks via Bazel's `--nocheck_visibility` argument. It would be great if we could fix this!