jaxnet

PRNG handling akin to parameters

Open samuela opened this issue 4 years ago • 2 comments

Handling parameters in JAX can get annoying, but what concerns me even more is handling PRNG keys. JAX has done a lot of great work to build a very strong PRNG system, but splitting and managing random keys by hand is messy and error-prone. It's alarmingly easy to accidentally reuse a PRNG key. It would be great to have a system analogous to @parametrized and parameter(), but for random keys and seeds.
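To make the reuse hazard concrete: in JAX the same key always produces the same sample, so handing one key to two draws silently makes them identical. A minimal sketch using only `jax.random`; the seed 0 and the shapes are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Pitfall: reusing one key gives identical "random" draws.
a = jax.random.normal(key, shape=(3,))
b = jax.random.normal(key, shape=(3,))
print(bool(jnp.array_equal(a, b)))  # True: the two samples are the same

# Correct: split the key and consume each subkey exactly once.
key, subkey_a, subkey_b = jax.random.split(key, 3)
c = jax.random.normal(subkey_a, shape=(3,))
d = jax.random.normal(subkey_b, shape=(3,))
print(bool(jnp.array_equal(c, d)))  # False: independent draws
```

Nothing in JAX raises an error for the first pattern, which is why a transform that does the splitting automatically is attractive.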

I envision an API providing something like @random and rng():

@random
def my_func(x):
  W = jax.random.normal(rng(), shape=(2, 2))
  b = jax.random.exponential(rng(), shape=(2,))
  return W @ x + b

And then ~magic~ happens, after which we get a function like:

import jax

def my_func(x, rng):
  # Split off a fresh subkey for each draw; the carried key is never reused.
  rng0, rng = jax.random.split(rng)
  W = jax.random.normal(rng0, shape=(2, 2))
  rng1, rng = jax.random.split(rng)
  b = jax.random.exponential(rng1, shape=(2,))
  return W @ x + b

samuela, Apr 08 '20 00:04

You can already use random_key() within @parametrized:

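from jax import numpy as np, random
from jaxnet import parametrized, random_key
def Dropout(rate):
    @parametrized
    def dropout(inputs):
        # Rescale kept units by 1/keep_rate so the expected activation is unchanged.
        keep_rate = 1 - rate
        keep = random.bernoulli(random_key(), keep_rate, inputs.shape)
        return np.where(keep, inputs / keep_rate, 0)
    return dropout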
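For comparison, here is roughly the bookkeeping that random_key() hides: the same inverted dropout in plain JAX with the key threaded by hand. A minimal sketch; the function name and argument order are illustrative and not jaxnet's.

```python
import jax
import jax.numpy as jnp


def dropout(key, inputs, rate):
    """Inverted dropout: zero each unit with probability `rate`, rescale the rest."""
    keep_rate = 1 - rate
    keep = jax.random.bernoulli(key, keep_rate, inputs.shape)
    return jnp.where(keep, inputs / keep_rate, 0)


key = jax.random.PRNGKey(0)
x = jnp.ones((4, 8))
# Every call that should get an independent mask needs its own subkey.
key, k1, k2 = jax.random.split(key, 3)
out1 = dropout(k1, x, rate=0.5)
out2 = dropout(k2, x, rate=0.5)
```

Passing `k1` twice would reproduce the same mask, which is exactly the silent reuse bug the issue is about.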

An independent seed transform like the one you describe would make sense; if I find time, I will factor it out.

juliuskunze, Apr 09 '20 08:04

Neat, I was not aware of that! Yeah, I think having a separate transform would be great.

samuela, Apr 09 '20 16:04