
Fast custom division by 0

Open juliuskunze opened this issue 3 years ago • 3 comments

Is there an efficient GPU (i.e., XLA one-op) implementation of floating-point division available in JAX where division by 0 returns 0 instead of ~~raising an error~~ inf? In my use case, the divisor is guaranteed to be non-negative, and is 0 only if the dividend is 0. Two somewhat unsatisfying implementations I considered are

  • c / jnp.where(d, d, 1), which is implemented with the broadcast_in_dim and select lax primitives, and
  • c / (d + np.finfo(d.dtype).tiny), which has a slight semantic difference and requires more bit operations than necessary.

I was hoping for something that is faster and simpler than both of these.
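
For concreteness, the two workarounds above can be sketched as follows (the names f_where and f_tiny are just illustrative):

```python
import numpy as np
import jax.numpy as jnp

def f_where(c, d):
  # Replace zero divisors with 1 before dividing; where d == 0 the
  # dividend c is also 0 by assumption, so the result is 0 / 1 == 0.
  return c / jnp.where(d, d, 1)

def f_tiny(c, d):
  # Shift the divisor by the smallest positive normal float instead.
  # This avoids the select, but slightly perturbs results for small d.
  return c / (d + np.finfo(d.dtype).tiny)

c = jnp.array([0.0, 1.0, 2.0], dtype=jnp.float32)
d = jnp.array([0.0, 2.0, 4.0], dtype=jnp.float32)
print(f_where(c, d))  # [0.  0.5 0.5]
print(f_tiny(c, d))   # [0.  0.5 0.5]
```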

juliuskunze avatar Oct 23 '20 06:10 juliuskunze

Have you tried benchmarking these? Once compiled to XLA, the where-based solution shows execution time comparable to the straightforward approach:

```python
import numpy as np
import jax.numpy as jnp
from jax import make_jaxpr, jit

@jit
def f1(x, y):
  return x / y

@jit
def f2(x, y):
  return jnp.where(y != 0, x / y, 0)

x = jnp.array(np.random.randint(0, 10, (1000, 1000))).astype(jnp.float32)
y = jnp.array(np.random.randint(0, 10, (1000, 1000))).astype(jnp.float32)

# trigger compilations...
f1(x, y)
f2(x, y)

%timeit f1(x, y).block_until_ready()
# 1000 loops, best of 3: 1.03 ms per loop
%timeit f2(x, y).block_until_ready()
# 1000 loops, best of 3: 1.06 ms per loop
```
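
One way to see why the timings are so close is to inspect the lowered computation (a sketch; exact primitive names such as select vs. select_n vary across JAX versions): jnp.where lowers to an elementwise select, not a control-flow branch, so both x / y and 0 are computed for every element and one is picked, with no divergence.

```python
import jax.numpy as jnp
from jax import make_jaxpr

def f2(x, y):
  return jnp.where(y != 0, x / y, 0)

x = jnp.ones((2, 2), jnp.float32)
y = jnp.ones((2, 2), jnp.float32)

# The jaxpr shows an elementwise ne / div / select sequence: the
# division happens unconditionally and a select chooses per element.
jaxpr = make_jaxpr(f2)(x, y)
print(jaxpr)
```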

jakevdp avatar Oct 24 '20 00:10 jakevdp

For GPU, you could look at the PTX instruction set and see if there's an instruction or instruction sequence that you'd like XLA to emit here?

jekbradbury avatar Oct 25 '20 06:10 jekbradbury

@jakevdp I hadn't yet (probably should have); thanks for doing this, very interesting! How would you explain this behavior, i.e., is the branching optimized out during compilation, or is it just dirt-cheap? I somehow imagined branching would be more expensive on GPU/TPU (where I assume you ran this). Are you certain that overhead like host communication doesn't dominate the cost?

@jekbradbury Thanks for the pointer! That led me to the spec of the PTX division ops, all of which seem to return inf on division by 0, with no alternative behavior. The spec even includes an explicit example of avoiding division by 0. So the suggested implementations are probably as good as it gets on top of CUDA/PTX. It also made me realize that division by zero doesn't even cause a warning in JAX by default, unlike NumPy.
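
A small check of that last point (a sketch, using array division so NumPy's warning machinery is exercised):

```python
import warnings
import numpy as np
import jax.numpy as jnp

# NumPy emits a RuntimeWarning on float division by zero and returns inf...
with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter("always")
  np_result = np.array([1.0], np.float32) / np.array([0.0], np.float32)

# ...while JAX follows IEEE-754 silently, with no warning by default.
jax_result = jnp.array([1.0]) / jnp.array([0.0])

print(np_result, len(caught) > 0)  # [inf] True
print(jax_result)                  # [inf]
```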

juliuskunze avatar Oct 25 '20 11:10 juliuskunze

I'm pretty sure Jake's suggestion is about as good as we are going to do. Let us know if it isn't fast enough!

hawkinsp avatar Aug 12 '22 21:08 hawkinsp