
efficient untrue batching of `random_bit_generator`

Open · froystig opened this issue 1 year ago · 1 comment

Over a batch of keys, the batching rule for the `random_bit_generator` primitive emits a loop (via `lax.map`): https://github.com/google/jax/blob/42ae8432185bf03f61ddd2e7bc279d3abb5247fd/jax/_src/lax/control_flow/loops.py#L2012-L2024
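As a rough illustration of what that loop amounts to, here is a minimal sketch. It assumes raw RBG keys are `uint32[4]` arrays, as accepted by the public wrapper `jax.lax.rng_bit_generator`. The helper `looped_bits` is hypothetical: it only mirrors the idea of the linked rule (one generator call per key inside `lax.map`) and is not the code in `loops.py`.

```python
import jax.numpy as jnp
from jax import lax


def looped_bits(keys, shape):
    """Hypothetical sketch: draw bits for a batch of raw RBG keys with a loop.

    keys: uint32 array of shape (batch, 4), one raw key per row.
    Returns (new_keys, bits), with bits of shape (batch, *shape).
    """
    def one(key):
        # A single RngBitGenerator call for one key; returns (new_key, bits).
        return lax.rng_bit_generator(key, shape, dtype=jnp.uint32)

    # lax.map lowers to a sequential loop over the leading (batch) axis,
    # which is the non-vectorized behavior described above.
    return lax.map(one, keys)


keys = jnp.arange(4 * 4, dtype=jnp.uint32).reshape(4, 4)
new_keys, bits = looped_bits(keys, (8,))
print(bits.shape)  # (4, 8)
```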

This works around the corresponding `RngBitGenerator` HLO op not being batchable. But looping violates the operational expectation of `vmap` that everything is vectorized. Downstream, users who switch RNG implementations also take a surprising performance hit.
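For context, switching RNG implementations looks roughly like this from the caller's side. This is a minimal sketch, assuming a JAX version where `jax.random.key` accepts `impl="rbg"`. Vectorizing a draw over a batch of RBG keys with `jax.vmap` is the case that reaches the batching rule; only the output shape is meaningful here.

```python
import jax

# Opt into the RBG-based PRNG implementation instead of the default.
root = jax.random.key(0, impl="rbg")
keys = jax.random.split(root, 4)  # a batch of 4 RBG keys


def draw(k):
    # One draw of 8 random bit patterns per key.
    return jax.random.bits(k, (8,))


# vmap batches the underlying bit-generator primitive over the keys; in the
# JAX version the issue describes, this is where the loop is introduced.
bits = jax.vmap(draw)(keys)
print(bits.shape)  # (4, 8)
```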

We could consider two options:

  1. Emit an unrolled loop. Drawback: program size grows with the batch size.
  2. Generate a batch of random numbers from a single key in the batch, dropping the remaining keys. Drawback: this violates `vmap` semantics with respect to the random values generated, although the output is, in a sense, "statistically" the same.
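Option 1 can be sketched as a Python loop that runs at trace time, so the traced program contains one generator call per batch element. This is a minimal sketch with a hypothetical helper `unrolled_bits`, assuming raw `uint32[4]` keys; comparing jaxpr sizes for two batch sizes shows the program-size drawback.

```python
import jax
import jax.numpy as jnp
from jax import lax


def unrolled_bits(keys, shape):
    """Hypothetical sketch of option 1: unroll the loop over the batch.

    keys: uint32 array of shape (batch, 4). The Python loop runs while
    tracing, so the emitted program has `batch` separate generator calls.
    """
    new_keys, bits = [], []
    for i in range(keys.shape[0]):
        k, b = lax.rng_bit_generator(keys[i], shape, dtype=jnp.uint32)
        new_keys.append(k)
        bits.append(b)
    return jnp.stack(new_keys), jnp.stack(bits)


def program_size(batch):
    # Number of top-level equations in the traced program for this batch size.
    keys = jnp.zeros((batch, 4), dtype=jnp.uint32)
    return len(jax.make_jaxpr(lambda k: unrolled_bits(k, (8,)))(keys).jaxpr.eqns)


print(program_size(4) < program_size(8))  # True
```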

Let's try number 2.
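A minimal sketch of option 2, assuming raw `uint32[4]` keys: draw all `batch * prod(shape)` values with a single call keyed by the first key, then reshape. The helper `single_key_bits` is hypothetical and returns only the bits; how a real batching rule should produce key outputs for the dropped keys is left open. The final check shows the semantic cost: the other keys have no effect on the output.

```python
import math

import jax.numpy as jnp
from jax import lax


def single_key_bits(keys, shape):
    """Hypothetical sketch of option 2: one generator call for the whole batch.

    keys: uint32 array of shape (batch, 4). Only keys[0] is used; the rest
    of the batch is dropped, which is the vmap-semantics violation.
    """
    batch = keys.shape[0]
    # A single call producing batch * prod(shape) values: fully vectorized.
    n = batch * math.prod(shape)
    _, flat = lax.rng_bit_generator(keys[0], (n,), dtype=jnp.uint32)
    return flat.reshape((batch, *shape))


keys = jnp.arange(4 * 4, dtype=jnp.uint32).reshape(4, 4)
other = keys.at[1:].set(0)  # change every key except the first

a = single_key_bits(keys, (8,))
b = single_key_bits(other, (8,))
print(a.shape, bool((a == b).all()))  # (4, 8) True
```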

The RBG operation is already non-portable across platforms and XLA flags, and in some cases its output is affected by sharding. So arguably, callers who opt into RBG-based RNGs already expect unusual semantics. By contrast, hardly anyone expects the performance hit.

cc @mattjj, @dlwh

froystig · Dec 21 '23 18:12

This ought to address #16792

froystig · Feb 17 '24 18:02