jaxnet
PRNG handling akin to parameters
Handling parameters in JAX can get annoying, but what concerns me even more is handling PRNG keys. JAX has done a lot of great work to build a very strong PRNG system, but splitting and managing random keys by hand is messy and error-prone: it's alarmingly easy to accidentally reuse a PRNG key. It would be great to have a system analogous to `@parametrized` and `parameter()`, but for random keys and seeds.
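To make the reuse pitfall concrete, here is a minimal plain-JAX illustration. Passing the same key to two draws silently produces identical samples, and the usual fix is to `jax.random.split` the key first:

```python
import jax

key = jax.random.PRNGKey(0)

# Bug: the same key is reused, so b is an exact copy of a.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# Fix: split the key and give each draw its own subkey.
k1, k2 = jax.random.split(key)
c = jax.random.normal(k1, (3,))
d = jax.random.normal(k2, (3,))

print(a, b)  # identical arrays
print(c, d)  # independent draws
```

Nothing warns about the first pattern, which is why a transform that hands out keys automatically is attractive.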
I envision an API providing something like `@random` and `rng()`:
```python
@random
def my_func(x):
    W = jax.random.normal(rng(), shape=(2, 2))
    b = jax.random.exponential(rng(), shape=(2,))
    return W @ x + b
```
And then ~ magic ~ happens, after which we get a function that threads the key explicitly:
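```python
import jax


def my_func(x, rng):
    # rng must be a PRNG key, e.g. jax.random.PRNGKey(0).
    # Split off a subkey for each draw and carry the remainder forward.
    rng0, rng = jax.random.split(rng)
    W = jax.random.normal(rng0, shape=(2, 2))
    rng1, rng = jax.random.split(rng)
    b = jax.random.exponential(rng1, shape=(2,))
    return W @ x + b
```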
You can already use `random_key()` within `@parametrized`:
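```python
from jax import random
import jax.numpy as np
from jaxnet import parametrized, random_key  # assumed import path
def Dropout(rate):  # `rate` is captured by the inner layer
    @parametrized
    def dropout(inputs):
        keep_rate = 1 - rate
        keep = random.bernoulli(random_key(), keep_rate, inputs.shape)
        return np.where(keep, inputs / keep_rate, 0)
    return dropout
```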
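For comparison, here is a sketch of the same inverted-dropout computation in plain JAX, with the key passed in explicitly instead of coming from `random_key()`. The function name and signature are illustrative only, not jaxnet's API. Each element survives with probability `keep_rate` and survivors are scaled by `1 / keep_rate`, so each element's expected value is unchanged:

```python
import jax
import jax.numpy as jnp


def dropout(inputs, rate, key):
    """Hypothetical explicit-key dropout; `key` is a jax.random.PRNGKey."""
    keep_rate = 1 - rate
    # True with probability keep_rate, independently per element.
    keep = jax.random.bernoulli(key, keep_rate, inputs.shape)
    # Zero the dropped elements and rescale the kept ones.
    return jnp.where(keep, inputs / keep_rate, 0)


x = jnp.ones(8)
print(dropout(x, 0.5, jax.random.PRNGKey(0)))
```

Every caller now has to create, split and pass `key` itself, which is exactly the bookkeeping `random_key()` hides.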
An independent `seed` transform as you describe would make sense; if I find time, I will factor it out.
Neat, I wasn't aware of that! Yeah, I think having a separate transform would be great.