Backend-agnostic stateless PRNGs

Open polvalente opened this issue 3 years ago • 2 comments

As discussed in #331, we need an approach for having portable PRNGs that can be applied for any backend.

We definitely need the ability to seed RNGs, and I'm confident this approach will work for Torchx and other backends, but I'm not 100% sure how it will extend to EXLA and other Tensor compilers that can't depend on a stateful RNG. As an example, we can set the RNG seed as an executable run option in EXLA, but that would apply to the entire execution and not individual calls to random_uniform within defn. It would also need to be passed as a compile option rather than directly to calls to random_x.

I think it is probably best to move forward with JAX-style stateless PRNGs instead, because we can implement them with our current API and have a solution that extends to every compiler and backend. The JAX PRNGs are also presented as well suited to distributed/parallel computing, and I think that aligns with some of our future goals. On the EXLA side, reworking this would probably involve using RngBitGenerator and reimplementing our current random functions in terms of it and other primitives.
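
For illustration, here is a minimal sketch of what a stateless, key-based PRNG looks like, written in plain Python. It is not the JAX or Nx implementation: SHA-256 stands in for a counter-based generator such as ThreeFry, and `make_key`, `split`, and `random_uniform` are hypothetical names chosen for this example. The point is only that the output depends solely on the key passed in, so there is no hidden state for a compiler or backend to track.

```python
import hashlib
import struct
from typing import List


def make_key(seed: int) -> bytes:
    """Derive a 32-byte key from a non-negative integer seed (hypothetical helper)."""
    return hashlib.sha256(b"seed" + struct.pack("<Q", seed)).digest()


def split(key: bytes, num: int = 2) -> List[bytes]:
    """Derive `num` child keys from `key`; the parent key is never mutated."""
    return [
        hashlib.sha256(b"split" + key + struct.pack("<Q", i)).digest()
        for i in range(num)
    ]


def random_uniform(key: bytes, n: int) -> List[float]:
    """Return n floats in [0, 1) that depend only on `key` and the element index."""
    out = []
    for i in range(n):
        block = hashlib.sha256(b"bits" + key + struct.pack("<Q", i)).digest()
        bits = struct.unpack("<Q", block[:8])[0]
        # Keep the top 53 bits so the result is an exact double in [0, 1).
        out.append((bits >> 11) / float(1 << 53))
    return out


key = make_key(42)
key_a, key_b = split(key)

# Same key, same numbers: no global seed or generator state is involved.
same_again = random_uniform(key_a, 4) == random_uniform(key_a, 4)
# Different child keys give different streams.
differs = random_uniform(key_a, 4) != random_uniform(key_b, 4)
```

Because nothing is stored between calls, reusing a key reproduces the same values, which is why callers are expected to split a key before each new use.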

Originally posted by @seanmor5 in https://github.com/elixir-nx/nx/issues/331#issuecomment-798775291

polvalente, Mar 13 '21 20:03

Leaving some notes here:

JAX PRNG: https://github.com/google/jax/blob/537e35b0fa2c2126cdd22a2e346b65ce11395f80/jax/_src/prng.py
It uses ThreeFry (https://bashtage.github.io/randomgen/bit_generators/threefry.html), whose key system allows parallel use cases.
XLA RngBitGenerator implementation: https://github.com/google/jax/blob/537e35b0fa2c2126cdd22a2e346b65ce11395f80/jax/_src/prng.py#L543
XLA RngBitGenerator operation semantics: https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator
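
The key system is what makes the parallel case work: each worker can derive its own key from a shared root key and its worker index, with no coordination and no shared generator state. A rough Python sketch of that idea, assuming SHA-256 as a stand-in for ThreeFry and using hypothetical helpers `fold_in`, `random_bits`, and `worker`:

```python
import hashlib
import struct


def fold_in(key: bytes, data: int) -> bytes:
    """Mix an integer (for example a worker index) into a key to get a new key."""
    return hashlib.sha256(b"fold" + key + struct.pack("<Q", data)).digest()


def random_bits(key: bytes, n: int) -> list:
    """Return n 32-bit integers determined only by `key` and the element index."""
    return [
        struct.unpack("<I", hashlib.sha256(key + struct.pack("<Q", i)).digest()[:4])[0]
        for i in range(n)
    ]


def worker(root_key: bytes, worker_id: int, n: int = 3) -> list:
    """Each worker derives its own key locally; nothing is shared or mutated."""
    return random_bits(fold_in(root_key, worker_id), n)


root = hashlib.sha256(b"seed-0").digest()

# The result for a given worker does not depend on the order of execution.
in_order = {w: worker(root, w) for w in range(4)}
reversed_order = {w: worker(root, w) for w in reversed(range(4))}
```

With a stateful generator, the values each worker sees would depend on scheduling order; here they depend only on the root key and the worker index.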

If we set the RngSeed in XLA (https://github.com/pytorch/xla/blob/master/torch_xla/csrc/tensor.cpp#L346), would this apply to all future kernels as well, since it seems to be set on the DeviceCtx? Would future kernels reset the seed to 0 if the context does not restart? (I'm just thinking of a temporary stopgap that can be deprecated later once proper PRNG support rolls out.)

Would implementing the JAX way require every call to a random function inside defn to be passed the stateless PRNG context explicitly?
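
In JAX itself, sampling functions such as `jax.random.uniform` do take a key as an explicit argument, and callers derive fresh keys with `jax.random.split`. The sketch below shows what that threading looks like across nested function calls. It is plain Python with SHA-256 as a stand-in generator, and `split`, `uniform`, `noisy_pair`, and `pipeline` are hypothetical names; how this would map onto defn is not settled in this thread.

```python
import hashlib
import struct


def split(key: bytes, num: int = 2) -> list:
    """Derive `num` independent child keys from `key`."""
    return [
        hashlib.sha256(b"split" + key + struct.pack("<Q", i)).digest()
        for i in range(num)
    ]


def uniform(key: bytes, n: int) -> list:
    """Return n floats in [0, 1) determined only by `key`."""
    vals = []
    for i in range(n):
        digest = hashlib.sha256(b"bits" + key + struct.pack("<Q", i)).digest()
        vals.append((struct.unpack("<Q", digest[:8])[0] >> 11) / float(1 << 53))
    return vals


def noisy_pair(key: bytes, n: int):
    """A function needing two independent draws takes one key and splits it."""
    key_a, key_b = split(key)
    return uniform(key_a, n), uniform(key_b, n)


def pipeline(key: bytes, n: int):
    """The caller splits once per callee, so no key is ever used twice."""
    key_first, key_second = split(key)
    return noisy_pair(key_first, n), noisy_pair(key_second, n)


root_key = hashlib.sha256(b"seed-0").digest()
result = pipeline(root_key, 2)
```

Every function that consumes randomness has a key parameter, and the only bookkeeping is splitting before handing a key down.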

vans163, Apr 02 '22 15:04

JAX PRNG reference: https://github.com/google/jax/blob/main/docs/design_notes/prng.md

josevalim, Apr 28 '22 09:04