jaxga

use segment_sum for mv_multiply

RobinKa opened this issue 2 years ago.

  • Previously used a loop of add-at-index updates, which gets unrolled during tracing; now uses segment_sum to sum all contributions that share an output index.
  • JIT compilation is about 10x faster on CPU and 6x faster on GPU.
  • Runtime is up to 100x slower on CPU but about 5x faster on GPU.
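A minimal sketch of the two strategies in JAX, assuming multivectors are stored as `[num_blades, batch]` arrays (as the benchmark inputs suggest) and using a small hypothetical product table `TERMS` in place of the one jaxga derives from the algebra. Neither function is jaxga's actual `mv_multiply`; they only contrast unrolled add-at-index updates with a single `jax.ops.segment_sum`.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical product table (not jaxga's real one): each term is
# (lhs blade index, rhs blade index, output blade index, sign).
TERMS = [(0, 0, 0, 1.0), (0, 1, 1, 1.0), (1, 0, 1, -1.0), (1, 1, 0, 1.0)]
NUM_OUT = 2
LHS = np.array([t[0] for t in TERMS])
RHS = np.array([t[1] for t in TERMS])
OUT = np.array([t[2] for t in TERMS])
SIGN = np.array([t[3] for t in TERMS], dtype=np.float32)


@jax.jit
def multiply_loop(a, b):
    # Old style: a Python loop of add-at-index updates. Tracing unrolls it
    # into one scatter-add per term, so the traced program grows with len(TERMS).
    out = jnp.zeros((NUM_OUT,) + a.shape[1:], dtype=a.dtype)
    for i, j, k, s in TERMS:
        out = out.at[k].add(s * a[i] * b[j])
    return out


@jax.jit
def multiply_segment_sum(a, b):
    # New style: compute every term's product in one vectorized step, then
    # sum the terms that share an output index with a single segment_sum.
    products = jnp.asarray(SIGN)[:, None] * a[LHS] * b[RHS]  # [len(TERMS), batch]
    return jax.ops.segment_sum(products, OUT, num_segments=NUM_OUT)


a = jnp.ones((2, 3), dtype=jnp.float32)
b = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
print(multiply_loop(a, b))
print(multiply_segment_sum(a, b))
```

Both functions compute the same values; the difference is the traced program. The loop version emits one scatter-add per term, so its size (and compile time) grows with the number of terms, while the segment_sum version stays a fixed handful of gather, multiply and scatter operations regardless of the table size.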

It may be worth adding a flag to choose between the two implementations. The segment_sum version is very useful for large algebras, where JIT compilation takes a long time because the unrolled loop has to be analyzed. It could also be made the default on GPU.
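The flag suggested above could look something like this. `pick_impl` and `use_segment_sum` are hypothetical names, not part of jaxga; the automatic default (segment_sum only when JAX's default backend is a GPU) simply encodes the "default on GPU" idea.

```python
import jax


def pick_impl(loop_impl, segment_sum_impl, use_segment_sum=None):
    """Return one of two implementations (hypothetical helper, not jaxga API).

    use_segment_sum=None means "choose automatically": segment_sum on GPU,
    the unrolled loop elsewhere (where it had the faster runtime).
    """
    if use_segment_sum is None:
        use_segment_sum = jax.default_backend() == "gpu"
    return segment_sum_impl if use_segment_sum else loop_impl


print(pick_impl("loop", "segment_sum"))  # depends on the machine's backend
```

Passing `use_segment_sum=True` explicitly would still let CPU users opt in for large algebras, where compile time rather than runtime dominates.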

On CPU, the segment_sum runtime depends strongly on the batch size:

a_val, a_ind = jnp.ones([5, 10], dtype=jnp.float32), tuple((i,) for i in range(5))
b_val, b_ind = jnp.ones([5, 10], dtype=jnp.float32), tuple((i, i + 1) for i in range(5))

new (segment_sum):
Wall time (first call, includes JIT): 94 ms
10.8 µs ± 301 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

old (unrolled loop):
Wall time (first call, includes JIT): 1.04 s
10.9 µs ± 152 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

---
a_val, a_ind = jnp.ones([5, 100], dtype=jnp.float32), tuple((i,) for i in range(5))
b_val, b_ind = jnp.ones([5, 100], dtype=jnp.float32), tuple((i, i + 1) for i in range(5))

new (segment_sum):
Wall time (first call, includes JIT): 227 ms
48.6 µs ± 640 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

old (unrolled loop):
Wall time (first call, includes JIT): 676 ms
11.6 µs ± 464 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
---
a_val, a_ind = jnp.ones([10, 100], dtype=jnp.float32), tuple((i,) for i in range(5))
b_val, b_ind = jnp.ones([10, 100], dtype=jnp.float32), tuple((i, i + 1) for i in range(5))

new (segment_sum):
Wall time (first call, includes JIT): 261 ms
49.1 µs ± 1.49 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

old (unrolled loop):
Wall time (first call, includes JIT): 687 ms
11.6 µs ± 196 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
---
a_val, a_ind = jnp.ones([10, 1000], dtype=jnp.float32), tuple((i,) for i in range(5))
b_val, b_ind = jnp.ones([10, 1000], dtype=jnp.float32), tuple((i, i + 1) for i in range(5))

new (segment_sum):
Wall time (first call, includes JIT): 256 ms
1.19 ms ± 69.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

old (unrolled loop):
Wall time (first call, includes JIT): 558 ms
16.9 µs ± 234 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
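The results above look like IPython `%time` (first call, which includes tracing and compilation) and `%timeit` (steady-state runtime) output. Below is a plain-Python sketch that measures the two numbers separately, using only `jax.jit` and `block_until_ready`; `time_jit_and_runtime`, `toy_segment_sum` and `SEGMENT_IDS` are illustrative helpers, and the toy function is not jaxga's `mv_multiply`.

```python
import time

import jax
import jax.numpy as jnp


def time_jit_and_runtime(fn, *args, repeats=100):
    """Return (first_call_seconds, mean_seconds_per_later_call)."""
    jitted = jax.jit(fn)
    # First call for a new input shape: tracing + XLA compilation + one run.
    start = time.perf_counter()
    jitted(*args).block_until_ready()
    first_call = time.perf_counter() - start
    # Later calls reuse the compiled executable; block_until_ready waits for
    # JAX's asynchronous dispatch so we time the computation, not just the launch.
    start = time.perf_counter()
    for _ in range(repeats):
        jitted(*args).block_until_ready()
    per_call = (time.perf_counter() - start) / repeats
    return first_call, per_call


# Hypothetical output index for each of 5 input rows.
SEGMENT_IDS = jnp.array([0, 1, 1, 0, 2])


def toy_segment_sum(x):
    return jax.ops.segment_sum(x, SEGMENT_IDS, num_segments=3)


for batch in (10, 100, 1000):
    x = jnp.ones((5, batch), dtype=jnp.float32)
    first, per_call = time_jit_and_runtime(toy_segment_sum, x)
    print(f"batch={batch}: first call {first * 1e3:.1f} ms, then {per_call * 1e6:.1f} µs/call")
```

Sweeping the batch dimension this way makes it easy to check how the steady-state runtime of each implementation scales on a given backend.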

RobinKa, Nov 06 '21 17:11