Fast custom division by 0
Is there an efficient GPU (i.e. XLA one-op) implementation of floating-point division available in JAX where division by 0 returns 0, instead of ~raising an error~ `inf`? In my use case, the divisor is guaranteed to be non-negative, and is only 0 if the dividend is 0. Two somewhat unsatisfying implementations I considered are
- `c / jnp.where(d, d, 1)`, which is implemented with a `broadcast_in_dim` and a `select` lax primitive, and
- `c / (d + np.finfo(d.dtype).tiny)`, which has a slight semantic difference and requires more bit operations than necessary.
I was hoping for something that is faster and simpler than both of these.
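For concreteness, here is a minimal sketch of the two candidates as stand-alone functions (the `safe_div_*` names are just placeholders for illustration):

```python
import numpy as np
import jax.numpy as jnp
from jax import jit

@jit
def safe_div_where(c, d):
    # Where d == 0, divide by 1 instead; since c is also 0 there, the result is 0.
    return c / jnp.where(d, d, 1)

@jit
def safe_div_tiny(c, d):
    # Nudge the divisor by the smallest positive normal float so it is never exactly 0.
    return c / (d + np.finfo(d.dtype).tiny)
```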
Have you tried benchmarking these? Once compiled to XLA, the `where`-based solution shows execution time comparable to the straightforward approach:
```python
import numpy as np
import jax.numpy as jnp
from jax import make_jaxpr, jit

@jit
def f1(x, y):
    return x / y

@jit
def f2(x, y):
    return jnp.where(y != 0, x / y, 0)

x = jnp.array(np.random.randint(0, 10, (1000, 1000))).astype(jnp.float32)
y = jnp.array(np.random.randint(0, 10, (1000, 1000))).astype(jnp.float32)

# trigger compilations...
f1(x, y)
f2(x, y)

%timeit f1(x, y).block_until_ready()
# 1000 loops, best of 3: 1.03 ms per loop
%timeit f2(x, y).block_until_ready()
# 1000 loops, best of 3: 1.06 ms per loop
```
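If you want to double-check which primitives the `where`-based version traces to, `make_jaxpr` (imported above) will print the jaxpr; a rough sketch:

```python
import jax.numpy as jnp
from jax import make_jaxpr

x = jnp.ones((4, 4), jnp.float32)
y = jnp.zeros((4, 4), jnp.float32)

# Prints the traced program; expect to see the comparison, division, and
# select (plus any broadcast) primitives, which XLA is then free to fuse
# into a single elementwise kernel.
print(make_jaxpr(lambda a, b: jnp.where(b != 0, a / b, 0))(x, y))
```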
For GPU, you could look at the PTX instruction set and see if there's an instruction or instruction sequence that you'd like XLA to emit here?
@jakevdp I hadn't yet (I probably should have), thanks for doing this, very interesting! How would you explain this behavior, i.e. is the branching optimized out during compilation or just dirt-cheap? I somehow imagined branching to be more expensive on GPU/TPU (where I assume you ran this). Are you certain that overhead like host communication doesn't dominate the cost?
@jekbradbury Thanks for the pointer! That made me find the spec of the PTX division ops, which all just seem to return `inf` on division by 0, with no alternative behavior. The spec even has an explicit example of guarding against division by 0, so the suggested implementations are probably as good as it gets on top of CUDA/PTX. It also made me realize that division by 0 doesn't even cause a warning in JAX by default, unlike in NumPy.
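For reference, that difference in default behavior is easy to see; a minimal sketch (the exact NumPy warning text varies by version):

```python
import numpy as np
import jax.numpy as jnp

# NumPy emits a RuntimeWarning ("divide by zero encountered ...") and returns inf:
print(np.array(1.0, np.float32) / np.array(0.0, np.float32))

# JAX returns inf silently -- no warning is raised by default:
print(jnp.array(1.0) / jnp.array(0.0))
```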
I'm pretty sure Jake's suggestion is about as good as we are going to do. Let us know if it isn't fast enough!