Backend-agnostic stateless PRNGs
As discussed in #331, we need an approach to portable PRNGs that works with any backend.
We definitely need the ability to seed RNGs, and I'm confident this approach will work for Torchx and other backends, but I'm not 100% sure how it will extend to EXLA and other tensor compilers that can't depend on a stateful RNG. As an example, we can set the RNG seed as an executable run option in EXLA, but that would apply to the entire execution and not to individual calls to `random_uniform` within `defn`. It would also need to be passed as a compile option rather than directly to calls to `random_x`.

I think it is probably best to move forward instead with JAX-style stateless PRNGs, because we can implement them with our current API and have a solution that extends to every compiler and backend. JAX's PRNGs are also designed with distributed/parallel computing in mind, and I think that aligns with some of our future goals. Reworking this from the EXLA perspective would probably involve using `RngBitGenerator` and reimplementing our current random functions in terms of it and other primitives.
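A rough sketch of what "reimplementing random functions in terms of a bit generator" could look like: given raw 32-bit words (for example from `RngBitGenerator`), a uniform sampler can be built with the common mantissa-filling trick (JAX's uniform uses a similar approach). The function names are hypothetical and this is not Nx's or EXLA's actual implementation.

```python
import struct

def bits_to_uniform(bits):
    """Map raw 32-bit words (e.g. the output of a bit generator such as
    RngBitGenerator) to float32-valued samples in [0, 1).

    Technique: keep the top 23 bits as a float32 mantissa, force the exponent
    of 1.0f so the value lies in [1, 2), reinterpret the bits as a float,
    then subtract 1.0.
    """
    samples = []
    for word in bits:
        mantissa = (word & 0xFFFFFFFF) >> 9            # top 23 of 32 bits
        float_bits = mantissa | 0x3F800000             # sign 0, exponent of 1.0f
        (value,) = struct.unpack("<f", struct.pack("<I", float_bits))
        samples.append(value - 1.0)                    # [1, 2) -> [0, 1)
    return samples


def uniform(bits, minval=0.0, maxval=1.0):
    """Hypothetical higher-level sampler: affine-rescale the [0, 1) samples."""
    return [minval + (maxval - minval) * u for u in bits_to_uniform(bits)]
```

For example, `bits_to_uniform([0, 0x80000000, 0xFFFFFFFF])` yields `0.0`, `0.5` and `1 - 2**-23`. The point is that only a bit generator, shifts, bitwise ops and a bitcast are needed, all of which every backend can provide.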
Originally posted by @seanmor5 in https://github.com/elixir-nx/nx/issues/331#issuecomment-798775291
Leaving some notes here:
- JAX PRNG: https://github.com/google/jax/blob/537e35b0fa2c2126cdd22a2e346b65ce11395f80/jax/_src/prng.py
- It uses ThreeFry (https://bashtage.github.io/randomgen/bit_generators/threefry.html), whose key system enables parallel use cases
- JAX's XLA `RngBitGenerator`-based implementation: https://github.com/google/jax/blob/537e35b0fa2c2126cdd22a2e346b65ce11395f80/jax/_src/prng.py#L543
- XLA operation semantics for `RngBitGenerator`: https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator
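For reference, a pure-Python sketch of the 20-round Threefry-2x32 hash, following the structure of `threefry_2x32` in the linked JAX `prng.py` (same rotation constants and key schedule). It uses plain Python ints masked to 32 bits. The `random_bits` counter-mode helper is a hypothetical simplification and does not reproduce JAX's exact bit layout.

```python
MASK32 = 0xFFFFFFFF

# Rotation constants for Threefry-2x32, alternated every four rounds.
ROTATIONS = ((13, 15, 26, 6), (17, 29, 16, 24))

def _rotl32(x, r):
    """32-bit rotate left."""
    return ((x << r) | (x >> (32 - r))) & MASK32

def threefry2x32(key, counter):
    """Threefry-2x32, 20 rounds: hash a 2-word counter under a 2-word key."""
    k0, k1 = key[0] & MASK32, key[1] & MASK32
    ks = (k0, k1, k0 ^ k1 ^ 0x1BD11BDA)       # key schedule with parity word
    x0 = (counter[0] + ks[0]) & MASK32        # initial key injection
    x1 = (counter[1] + ks[1]) & MASK32
    for i in range(1, 6):                      # five groups of four rounds
        for r in ROTATIONS[(i - 1) % 2]:
            x0 = (x0 + x1) & MASK32            # mix: add, rotate, xor
            x1 = _rotl32(x1, r) ^ x0
        # Key injection after every group, plus the group index.
        x0 = (x0 + ks[i % 3]) & MASK32
        x1 = (x1 + ks[(i + 1) % 3] + i) & MASK32
    return x0, x1

def random_bits(key, n):
    """Hypothetical helper: n uint32 words by hashing counters (0, i)."""
    words = []
    for i in range((n + 1) // 2):
        words.extend(threefry2x32(key, (0, i)))
    return words[:n]
```

Because the output is a pure function of `(key, counter)`, any number of devices can generate disjoint blocks of bits in parallel just by using different counters, with no shared state. The sketch should reproduce the Random123 known-answer vectors that JAX's own tests check against.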
If we set the RngSeed in XLA (https://github.com/pytorch/xla/blob/master/torch_xla/csrc/tensor.cpp#L346), would it apply to all future kernels as well, since it seems to be set on the DeviceCtx? Would future kernels reset the seed to 0 if the context does not restart? (Just thinking of a quick temporary stopgap that can be deprecated later once proper PRNG support rolls out.)
Would implementing it the JAX way require passing the stateless PRNG context (the key) along to every call to a random function inside `defn`?
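In JAX the answer is yes: every `jax.random` function takes a key argument, and callers derive fresh subkeys with `jax.random.split` instead of reusing one. A minimal sketch of that key-threading style, assuming a SHA-256 hash as a stand-in for Threefry; `prng_key`, `split`, `uniform`, `init_layer` and `init_model` are hypothetical names, not an existing Nx API.

```python
import hashlib
import struct

MASK32 = 0xFFFFFFFF

def _mix(key, domain, counter):
    """Stand-in for Threefry: hash (key, domain, counter) to two uint32 words."""
    data = struct.pack("<IIII", key[0], key[1], domain, counter & MASK32)
    return struct.unpack("<II", hashlib.sha256(data).digest()[:8])

def prng_key(seed):
    """Hypothetical helper: turn an integer seed into a 2-word key."""
    return (seed & MASK32, (seed >> 32) & MASK32)

def split(key, num=2):
    """Derive `num` new keys from `key`; the input key is not modified."""
    return [_mix(key, 0, i) for i in range(num)]

def uniform(key, n):
    """n floats in [0, 1) that depend only on `key` (no hidden state)."""
    return [_mix(key, 1, i)[0] / 2**32 for i in range(n)]

def init_layer(key, fan_in, fan_out):
    # Every random call receives its own subkey.
    w_key, b_key = split(key)
    return uniform(w_key, fan_in * fan_out), uniform(b_key, fan_out)

def init_model(key):
    # The caller threads the key explicitly and splits once per consumer.
    k1, k2 = split(key)
    return [init_layer(k1, 4, 3), init_layer(k2, 3, 1)]
```

The cost is that keys must be threaded through the code, but in exchange `init_model(prng_key(0))` returns the same values on every backend and every run, and a compiler is free to reorder or parallelize the random calls because none of them touch shared state.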
JAX PRNG reference: https://github.com/google/jax/blob/main/docs/design_notes/prng.md