
RNGs: key types and custom implementations

Open: froystig opened this issue 3 years ago.

This issue tracks the introduction of typed random key arrays with pluggable PRNG implementations.

Work on this began with #6899, which mostly works out the "pluggable implementations" part, makes initial progress on the "typed" part, and sets things up throughout the codebase for backwards-compatibility and upgrading. The main configuration flag guarding the system upgrade is config.jax_enable_custom_prng.

The motivation for pluggability is straightforward: we'd like to set up various implementations for random bit generation, and still have all of jax.random work on top of any one of them. While we're at it, we could make this extensible.

The motivation for having keys reflected in types (somehow, at some level) is broadly to improve safety, and with it, to offer both users and the system some structural guarantees. Desiderata include:

  • Moving to a user-facing representation of key arrays that reflects that they are indeed key arrays. Currently key arrays are plain uint32 arrays to users, indistinguishable from any other data.
  • Restricting operations on key arrays. The current plain array representation allows for key-invalidating operations (e.g. manual updating, addition, ...). We'd like to disallow these, or at least render mistakes unlikely.
  • More/better opportunities to check key misuse, reuse, and so on.
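To make the first two points concrete, here is a minimal sketch contrasting the two representations. It is written against current JAX releases, which postdate this issue. In those releases `jax.random.PRNGKey` still returns the legacy uint32 form and `jax.random.key` returns a typed key. The sketch assumes the default threefry2x32 implementation.

```python
import jax

# Legacy representation: a plain uint32 array, indistinguishable from other data.
legacy = jax.random.PRNGKey(0)
print(legacy.dtype, legacy.shape)  # uint32 (2,) with the default threefry2x32 impl

# A key-invalidating operation such as addition is silently accepted.
tampered = legacy + 1

# Typed representation: the element type is a key, so a single key has shape ().
typed = jax.random.key(0)
print(typed.dtype, typed.shape)

# The underlying uint32 key data is still reachable, but only explicitly.
raw = jax.random.key_data(typed)
```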

Unsurprisingly, key types also help with mapping keys to RNG implementations (for the pluggability part above).

There are roughly two ways we've considered to endow jax with key types, specifically, key-element-type arrays:

One is to treat key types as a frontend component handled during staging and projected away to plain u32 in staged IR, analogous to how pytrees behave. Introducing a pytree type in Python that wraps an underlying key-data (say, uint32) array is not quite correct, though, since it would misbehave under vmap, scan, and even jax.tree_map. Instead, we might want to rely on something like the typeclass mechanisms still in development (e.g. vmappable, #8451) for this approach.

The other tack is to introduce key types into our IR and other internal machinery (with a corresponding lowering), and to map to that from a Python array-of-keys-like type during staging. This might confer some extra advantages downstream as well, such as more opportunities for checking throughout our front- and middle-end. This is roughly the approach we're taking (as of around summer '22).

Ahead of the complete upgrade, some by-products are already available. We've implemented a couple of PRNGs that serve as alternatives to the default threefry2x32 implementation (in #8067, #8123) and build on compiler bit generator primitives. The process-wide default RNG implementation can be controlled via jax.default_prng_impl and config.jax_default_prng_impl (from #8135) and can be accessed via jax.random.default_prng_impl (from #9186). These can be used to swap between the pre-defined RNG alternatives for the entire process at a time.
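Here is a minimal sketch of swapping the process-wide default. It assumes the config name `jax_default_prng_impl` and the implementation names `"threefry2x32"`, `"rbg"`, and `"unsafe_rbg"` as they exist in current JAX releases. It uses legacy `PRNGKey` keys so that the difference shows up in the raw key shape.

```python
import jax

# Switch the whole process to the "rbg" alternative.
# Names assumed from current JAX: "threefry2x32" (default), "rbg", "unsafe_rbg".
jax.config.update("jax_default_prng_impl", "rbg")

key = jax.random.PRNGKey(0)
# rbg key data holds four uint32 words, versus two for threefry2x32.
print(key.shape)

# All of jax.random works on top of the selected implementation.
sample = jax.random.normal(key, (3,))

# Restore the default so the rest of the process is unaffected.
jax.config.update("jax_default_prng_impl", "threefry2x32")
```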

froystig, Jan 20 '22

We plan on starting this migration once PRNGKeyArray is no longer a pytree, i.e., when we can register it as a JAX type through the upcoming typeclass mechanism (like vmappable, https://github.com/google/jax/pull/8451).

LenaMartens, Jan 21 '22

Indeed, the move away from a pytree is what's blocking us. Aside from that, we also have some operations to consider implementing, as categorized nicely in @LenaMartens' description for #8381. But these don't seem to present an upfront barrier to the same degree. We believe that we can implement what's needed as we upgrade.

froystig, Feb 17 '22

Just curious, and didn't know where to ask. Is there a reason that k.split() does something completely different than split(k) if k is a new-style KeyArray? Could this be a trap for users? It might be convenient to have split as a member function?

NeilGirdhar, Sep 14 '22

Also, is there an asymmetry between vmap and split? Currently, if you split a key array k:

h = split(split(k, 3), 5)

then h.shape is (3, 5).

But if you vmap a function that accepts a key:

def f(k: KeyArray) -> Any: ...

g = vmap(vmap(f))

Then g(h) will make the outer vmap correspond to the inner split? (If I have that right?) It seems like split should prepend rather than append the dimension for consistency?

NeilGirdhar, Sep 14 '22

> Just curious, and didn't know where to ask. Is there a reason that k.split() does something completely different than split(k) if k is a new-style KeyArray? Could this be a trap for users?

My plan was for numpy split to be inapplicable altogether, whether as a method or via jnp.split(keys), at least for the time being. You'll notice that it's only barely available now under the upgrade flag: you can incidentally call jnp.split(k) or k.split() when staging under jit or similar, but the implementation will then fail with an inscrutable error. This is a known, temporary discrepancy; the eager behavior is the intended one.

Do you expect a need for numpy-like splitting support?

> It might be convenient to have split as a member function?

I think there isn't a solid case for this. No current code expects it, and to your point it would be ambiguous.

froystig, Sep 27 '22

> It seems like split should prepend rather than append the dimension for consistency?

split prepends in correspondence with the very common unpacking usage, e.g.:

key1, key2, key3 = jax.random.split(key, 3)
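Here is a sketch of the axis ordering under discussion, assuming current JAX with typed keys via `jax.random.key`. It batches the second split with vmap, which sidesteps the question of how split itself treats an already-batched key array. It then shows which split each level of vmap(vmap(f)) maps over. `f` is a hypothetical per-key function.

```python
import jax

key = jax.random.key(0)

# split places the new axis first, so a single key's split unpacks directly.
key1, key2, key3 = jax.random.split(key, 3)

# Nested splitting, batched with vmap: the size-3 axis comes from the first
# split and the size-5 axis from the second.
keys = jax.random.split(key, 3)                        # shape (3,)
h = jax.vmap(lambda k: jax.random.split(k, 5))(keys)   # shape (3, 5)

# Hypothetical per-key function: draw one uniform float from a key.
def f(k):
    return jax.random.uniform(k)

# The outer vmap maps over axis 0 (size 3, the first split) and the inner
# vmap over axis 1 (size 5, the second split).
g = jax.vmap(jax.vmap(f))
out = g(h)                                             # shape (3, 5)
```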

froystig, Sep 27 '22

> My plan was for numpy split to be inapplicable altogether,

Yes, great plan!

I was just confused by the different result. Now I see that it's a numpy function being called on the underlying array.

> Do you expect a need for numpy-like splitting support?

I do not! I like the idea of making the RNG more opaque so that it can't be used as a numpy array. In particular, I love how the shape of a key array now returns the shape of the keys rather than the shape of the underlying array.

> No current code expects it, and to your point it would be ambiguous

I guess it's ambiguous if you are still imagining the object as a numpy-like array. But isn't your point above that you want to get away from that?

It's not a big deal, but just to explain my reasoning in case you're curious:

The usual criterion for deciding whether something should be a member function or a free function is whether the function can be implemented through the public interface. split cannot be implemented through a key array's public interface, so at least according to that criterion, it should be a member function of the object.

I agree that no code expects it, but that's because historically key arrays were numpy arrays, and so split had to be a free function. Going forward, it would be more convenient not to have to import split, and more logical to make it a member function.

This argument doesn't apply to functions like normal, gamma, etc. These are implemented through the public interface (random_bits), so they are rightly free functions.
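The sketch below illustrates this distinction with a sampler written purely against the public bits interface. It assumes current JAX, where `jax.random.bits` exposes raw random bits for a key. `uniform_from_bits` is a hypothetical helper, not how `jax.random.uniform` is actually implemented.

```python
import jax
import jax.numpy as jnp

def uniform_from_bits(key, shape=()):
    """Hypothetical sampler built only on the public jax.random.bits API.

    Keeps the top 24 of 32 random bits (float32's significand precision)
    and scales by 2**-24, giving float32 values in [0, 1).
    """
    bits = jax.random.bits(key, shape, dtype=jnp.uint32)
    return (bits >> 8).astype(jnp.float32) * jnp.float32(2.0 ** -24)

samples = uniform_from_bits(jax.random.key(0), (1000,))
```

Nothing in the helper reaches into the key's internals, which is the property the argument above uses to separate free functions from member functions.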

Anyway, no big deal 😄

NeilGirdhar, Sep 27 '22

I see. Thanks for spelling that out.

Although key arrays are not entirely numpy-like, they are still partly so. For example, they support transposition, reshaping, and a few others, and they offer these via methods as well as jax.numpy calls. So long as any amount of numpy is supported, I'm wary of possible ambiguity.
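Here is a short sketch of the numpy-like operations mentioned. It assumes current JAX typed key arrays, where reshaping and transposition are available both as methods and through jax.numpy.

```python
import jax
import jax.numpy as jnp

keys = jax.random.split(jax.random.key(0), 6)   # key array of shape (6,)

grid = keys.reshape(2, 3)       # method form
grid_t = jnp.transpose(grid)    # jax.numpy form
grid_T = grid.T                 # attribute form

print(grid.shape, grid_t.shape, grid_T.shape)
```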

I otherwise appreciate your case in favor of it. We can start with less for now and see whether we're drawn to add it.

froystig, Sep 28 '22

> So long as any amount of numpy is supported, I'm wary of possible ambiguity.

Yes, very good point.

> I otherwise appreciate your case in favor of it. We can start with less for now and see whether we're drawn to add it.

Sounds great. And I love the new RNG interface if it wasn't obvious 😄

NeilGirdhar, Sep 28 '22

6abefa197776994926f5d5330fe994f94f0090dc makes dispatch fast again for functions over new-style RNG keys!

froystig, Sep 13 '23

> I guess it's ambiguous if you are still imagining the object as a numpy-like array. [...]

Returning to this old thread with @NeilGirdhar – we've switched to thinking about this in a way that is hopefully more straightforward: key arrays are now like any other jax array from the user's point of view. They just have a different dtype.

The dtype determines what operations are allowed on the array. Element-type-polymorphic operations like transposition and slicing are fine. Addition is not, since the element type doesn't support it. Also, there is no longer a user-visible type distinction for the enclosing array. To check whether an array is a key array, we recommend issubdtype (e.g. jnp.issubdtype(x.dtype, jax.dtypes.prng_key)).
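Here is a sketch of this dtype-based view. It assumes a JAX release recent enough to provide `jax.random.key` and the abstract dtype `jax.dtypes.prng_key`. `is_key_array` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp

keys = jax.random.split(jax.random.key(0), 4)
data = jnp.arange(4, dtype=jnp.uint32)

def is_key_array(x):
    # Key arrays are ordinary jax arrays whose dtype is a PRNG key dtype.
    return jnp.issubdtype(x.dtype, jax.dtypes.prng_key)

print(is_key_array(keys), is_key_array(data))  # True False

# Element-type-polymorphic operations such as slicing are fine.
middle = keys[1:3]

# Arithmetic is not: the key element type does not support addition.
try:
    keys + 1
    rejected = False
except TypeError:
    rejected = True
```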

This is all covered by the JEP drafted in #17297. I just thought to highlight it!

froystig, Sep 14 '23

@froystig Thanks for linking me.

Sounds like an excellent design!

NeilGirdhar, Sep 15 '23