RNGs: key types and custom implementations
This issue tracks the introduction of typed random key arrays with pluggable PRNG implementations.
Work on this began with #6899, which mostly works out the "pluggable implementations" part, makes initial progress on the "typed" part, and sets things up throughout the codebase for backward compatibility and upgrading. The main configuration flag guarding the system upgrade is config.jax_enable_custom_prng.
The motivation for pluggability is straightforward: we'd like to set up various implementations for random bit generation, and still have all of jax.random work on top of any one of them. While we're at it, we could make this extensible.
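As a rough illustration of what "all of jax.random on top of any implementation" looks like from the user side, here is a minimal sketch. It assumes a recent JAX release where `jax.random.key` (with its `impl=` argument) and `jax.random.key_data` are available; neither call is named in this thread, and the API discussed here predates them.

```python
import jax

# Two keys from the same seed, backed by different PRNG implementations.
# "threefry2x32" is the default; "rbg" is one of the alternatives built on
# the compiler's bit-generator primitives.
fry_key = jax.random.key(0, impl="threefry2x32")
rbg_key = jax.random.key(0, impl="rbg")

# The same jax.random sampler runs unchanged on top of either implementation.
fry_sample = jax.random.normal(fry_key, (3,))
rbg_sample = jax.random.normal(rbg_key, (3,))

# The underlying key data differs per implementation:
# threefry2x32 keys hold 2 uint32 words, rbg keys hold 4.
print(jax.random.key_data(fry_key).shape)  # (2,)
print(jax.random.key_data(rbg_key).shape)  # (4,)
```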
The motivation for having keys reflected in types (somehow, at some level) is broadly to improve safety, and with it, to offer both users and the system some structural guarantees. Desiderata include:
- Moving to a user-facing representation of key arrays that reflects that they are indeed key arrays. Currently key arrays are plain uint32 arrays to users, indistinguishable from any other data.
- Restricting operations on key arrays. The current plain array representation allows for key-invalidating operations (e.g. manual updating, addition, ...). We'd like to disallow these, or at least render mistakes unlikely.
- More/better opportunities to check key misuse, reuse, and so on.
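To make the first point concrete, the sketch below contrasts the legacy uint32 representation with a typed key. It assumes a recent JAX where `jax.random.key` and `jax.random.key_data` exist and the legacy `jax.random.PRNGKey` still returns raw uint32 data (the default when the custom-PRNG flag is off).

```python
import jax
import jax.numpy as jnp

# Legacy style: a key is a plain uint32 array, indistinguishable from data.
raw_key = jax.random.PRNGKey(0)
print(raw_key.dtype, raw_key.shape)  # uint32 (2,)

# Typed style: the array has a dedicated key element type, and its shape
# counts keys rather than uint32 words.
typed_key = jax.random.key(0)
print(typed_key.shape)  # ()

# The raw bits are still reachable, but only through an explicit call.
print(jax.random.key_data(typed_key).shape)  # (2,)
```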
Unsurprisingly, key types also help for mapping keys to RNG implementations (for the pluggability part above).
There are roughly two ways we've thought to approach endowing jax with key types, specifically, key-element-type arrays:
One is as a frontend component handled during staging, projected away to plain u32 in staged IR, in analogy with how pytrees behave. Introducing a pytree type in Python that wraps an underlying key-data (say, uint32) array is not quite correct, since it would misbehave under vmap, scan, and even jax.tree_map. Instead, we might want to rely on something like the typeclass mechanisms still in development (e.g. vmappable, #8451) for this approach.
The other tack is to introduce key types into our IR and other internal machinery (with a corresponding lowering), and to map to that from a Python array-of-keys-like type during staging. This might confer some extra advantages downstream as well, such as more opportunities for checking throughout our front- and middle-end. This is roughly the approach we're taking (as of around summer '22).
Ahead of the complete upgrade, some by-products are already available. We've implemented a couple of PRNGs that serve as alternatives to the default threefry2x32 implementation (in #8067, #8123) and build on compiler bit generator primitives. The process-wide default RNG implementation can be controlled via jax.default_prng_impl and config.jax_default_prng_impl (from #8135) and can be accessed via jax.random.default_prng_impl (from #9186). These can be used to swap between the predefined RNG alternatives for the entire process.
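A minimal sketch of the process-wide switch via the config option, assuming a recent JAX where `jax_default_prng_impl` accepts the name `"rbg"` and `jax.random.key` picks up the configured default when no `impl=` is given. It restores the default afterwards, since the setting affects every key created later in the process.

```python
import jax

# Switch the process-wide default implementation; keys created afterwards
# without an explicit impl=... use it.
jax.config.update("jax_default_prng_impl", "rbg")
rbg_data_shape = jax.random.key_data(jax.random.key(0)).shape
print(rbg_data_shape)  # (4,)

# Restore the default threefry2x32 implementation.
jax.config.update("jax_default_prng_impl", "threefry2x32")
fry_data_shape = jax.random.key_data(jax.random.key(0)).shape
print(fry_data_shape)  # (2,)
```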
We plan on starting this migration once PRNGKeyArray is no longer a pytree, i.e. when we can register it as a JAX type through the upcoming typeclass mechanism (like vmappable, https://github.com/google/jax/pull/8451).
Indeed, the move away from a pytree is what's blocking us. Aside from that, we also have some operations to consider implementing, as categorized nicely in @LenaMartens' description for #8381. But these don't seem to present an upfront barrier to the same degree. We believe that we can implement what's needed as we upgrade.
Just curious, and didn't know where to ask. Is there a reason that k.split() does something completely different than split(k) if k is a new-style KeyArray? Could this be a trap for users? It might be convenient to have split as a member function?
Also, is there an asymmetry between vmap and split? Currently, if you split a key array k:
h = split(split(k, 3), 5)
then h.shape is (3, 5).
But if you vmap a function that accepts a key:
def f(k: KeyArray) -> Any: ...
g = vmap(vmap(f))
Then g(h) will make the outer vmap correspond to the inner split? (If I have that right?) It seems like split should prepend rather than append the dimension for consistency?
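The axis correspondence in question can be checked directly. This sketch assumes a recent JAX in which `jax.random.split` accepts only a single (scalar) key, so the nested split is written with `jax.vmap` instead of `split(split(k, 3), 5)`; the per-key function `f` here is a hypothetical stand-in that draws one normal sample.

```python
import jax

def f(k):
    # Hypothetical per-key computation: draw one standard normal sample.
    return jax.random.normal(k)

k = jax.random.key(0)
# Split k into 3 keys, then split each of those into 5.
h = jax.vmap(lambda sub: jax.random.split(sub, 5))(jax.random.split(k, 3))
print(h.shape)  # (3, 5)

# The outer vmap maps over axis 0 (size 3), the axis created by the
# split(k, 3) call; the inner vmap maps over the 5 keys split from each.
g = jax.vmap(jax.vmap(f))
print(g(h).shape)  # (3, 5)
```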
Just curious, and didn't know where to ask. Is there a reason that k.split() does something completely different than split(k) if k is a new-style KeyArray? Could this be a trap for users?
My plan was for numpy split to be inapplicable altogether, whether as a method or via jnp.split(keys), at least for the time being. You'll notice that it's only barely available now under the upgrade flag: you can incidentally do jnp.split(k) or k.split() currently when staging under jit or similar, but the implementation will err inscrutably. This is a known temporary discrepancy, but the eager behavior is the intended one.
Do you expect a need for numpy-like splitting support?
It might be convenient to have split as a member function?
I think there isn't a solid case for this. No current code expects it, and to your point it would be ambiguous.
It seems like split should prepend rather than append the dimension for consistency?
split prepends in correspondence with the very common unpacking usage, e.g.:
key1, key2, key3 = jax.random.split(key, 3)
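A runnable version of that unpacking pattern, assuming a recent JAX where `jax.random.key` creates a typed key and `jax.random.key_data` exposes its raw bits: the leading axis of size 3 is what the tuple unpacking consumes, leaving three scalar keys.

```python
import jax

key = jax.random.key(0)
# split(key, 3) yields a key array of shape (3,); unpacking iterates
# over its leading axis.
key1, key2, key3 = jax.random.split(key, 3)
print(key1.shape)  # ()
```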
My plan was for numpy split to be inapplicable altogether,
Yes, great plan!
I was just confused by the different result. Now I see that it's a numpy function being called on the underlying array.
Do you expect a need for numpy-like splitting support?
I do not! I like the idea of making the RNG more opaque so that it can't be used as a numpy array. In particular, I love how the shape of a key array now returns the shape of the keys rather than the shape of the underlying array.
No current code expects it, and to your point it would be ambiguous
I guess it's ambiguous if you are still imagining the object as a numpy-like array. But isn't your point above that you want to get away from that?
It's not a big deal, but just to explain my reasoning in case you're curious:
The usual design-pattern criterion for whether something should be a member function or a free function is whether the function can be implemented through the public interface. split cannot be implemented through a key array's public interface, so at least according to that criterion, it should be a member function of the object.
I agree that no code expects it, but that's because historically key arrays were numpy arrays, and so split had to be a free function. Going forward, it would be more convenient not to have to import split, and more logical to make it a member function.
This argument doesn't apply to functions like normal, gamma, etc. These are implemented through the public interface (random_bits), so they are rightly free functions.
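To illustrate the "implemented through the public interface" criterion, here is a sketch of a sampler built only on raw random bits. `my_uniform` is a hypothetical helper, not a JAX function, and the sketch uses the public `jax.random.bits` (assumed available in a recent JAX) in place of the internal `random_bits` named above.

```python
import jax
import jax.numpy as jnp

def my_uniform(key, shape=()):
    """Hypothetical uniform sampler on [0, 1) built only from raw bits.

    Keeps the top 24 of 32 random bits (exactly representable in float32)
    and scales them by 2**-24.
    """
    bits = jax.random.bits(key, shape, dtype=jnp.uint32)
    return (bits >> 8).astype(jnp.float32) * (2.0 ** -24)

samples = my_uniform(jax.random.key(0), (1000,))
print(samples.min() >= 0.0, samples.max() < 1.0)  # True True
```

A free function like this never needs to see inside the key, which is the distinction being drawn with split.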
Anyway, no big deal :smile:
I see. Thanks for spelling that out.
Although key arrays are not entirely numpy-like, they are still partly so. For example, they support transposition, reshaping, and a few others, and they offer these via methods as well as jax.numpy calls. So long as any amount of numpy is supported, I'm wary of possible ambiguity.
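The numpy-like shape operations mentioned here can be sketched as follows, assuming a recent JAX with typed keys from `jax.random.key`; both the method forms and the `jax.numpy` forms act on the key axes, not on the underlying uint32 words.

```python
import jax
import jax.numpy as jnp

keys = jax.random.split(jax.random.key(0), 6).reshape(2, 3)
print(keys.shape)  # (2, 3)

# Shape-only operations are available as methods and as jax.numpy calls.
print(keys.T.shape)               # (3, 2)
print(jnp.transpose(keys).shape)  # (3, 2)
print(keys[0].shape)              # (3,)
```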
I otherwise appreciate your case in favor of it. We can start with less for now and see whether we're drawn to add it.
So long as any amount of numpy is supported, I'm wary of possible ambiguity.
Yes, very good point.
I otherwise appreciate your case in favor of it. We can start with less for now and see whether we're drawn to add it.
Sounds great. And I love the new RNG interface if it wasn't obvious 😄
6abefa197776994926f5d5330fe994f94f0090dc makes dispatch fast again for functions over new-style RNG keys!
I guess it's ambiguous if you are still imagining the object as a numpy-like array. [...]
Returning to this old thread with @NeilGirdhar – we've switched to thinking about this in a way that is hopefully more straightforward: key arrays are now like any other jax array from the user's point of view. They just have a different dtype.
The dtype determines what operations are allowed on the array. Element-type-polymorphic operations like transposition and slicing are fine. Addition is not, since the element type doesn't support it. Also, there is no longer a user-visible type distinction for the enclosing array. To check for whether an array is a key array, we recommend issubdtype.
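A minimal sketch of the dtype-based behavior, assuming a recent JAX that exposes the abstract key dtype as `jax.dtypes.prng_key` and rejects arithmetic on key arrays with a `TypeError`.

```python
import jax
import jax.numpy as jnp

keys = jax.random.split(jax.random.key(0), 4)
data = jnp.arange(4, dtype=jnp.uint32)

# Recommended check: compare the dtype against the abstract key dtype.
print(jnp.issubdtype(keys.dtype, jax.dtypes.prng_key))  # True
print(jnp.issubdtype(data.dtype, jax.dtypes.prng_key))  # False

# Element-type-polymorphic operations such as slicing are allowed.
print(keys[1:].shape)  # (3,)

# Arithmetic is not, since the key element type doesn't support it.
rejected = False
try:
    _ = keys + 1
except TypeError:
    rejected = True
print(rejected)  # True
```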
This is all covered by the JEP drafted in #17297. I just thought to highlight it!
@froystig Thanks for linking me.
Sounds like an excellent design!