
current rng setup is full of footguns in jax

Open GallagherCommaJack opened this issue 1 year ago • 29 comments

right now unseeded calls to e.g. keras.random.uniform are going to acquire static seeds at trace time. this has a few undesirable consequences:

  1. subsequent calls will have the same randomness each time (e.g. dropout will have a fixed mask instead of random each step)
  2. the jax compiler cache will ~never hit, as the constant rng seed values will be different every time
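a minimal sketch of the first point, not keras's actual code path: `noisy` is a hypothetical function, and python's `random.randint` stands in for whatever an unseeded call does to pick its seed. the seed is drawn while jax traces the function, so it ends up as a constant in the compiled program.

```python
import random

import jax
import jax.numpy as jnp


@jax.jit
def noisy(x):
    # Stand-in for an unseeded random call (assumption, not keras internals):
    # this Python code only runs while jax traces the function.
    seed = random.randint(0, 2**31 - 1)
    key = jax.random.PRNGKey(seed)  # baked into the traced program as a constant
    return x + jax.random.uniform(key, x.shape)


x = jnp.zeros(3)
first = noisy(x)
second = noisy(x)  # reuses the existing trace, so the same seed is used again

# True: both calls produced identical "random" noise.
same_noise = bool(jnp.array_equal(first, second))
```

a fresh process would draw a different seed at trace time, so the traced program itself differs from run to run, which is plausibly why a cache keyed on the compiled program would miss.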

to get around this, some kind of rng state management is necessary. flax does this with hierarchical management of rngs from the Scope. such an approach is fairly complex, however, and there might be simpler options, e.g. a single global rng state that gets included with the training state in model.fit. unseeded rng calls would then do something along the lines of

state.seed, local_seed = jax.random.split(state.seed)
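a rough sketch of how that split could be threaded through a jitted step, assuming the rng state is passed in and returned like any other piece of training state. `dropout` and `train_step` are hypothetical names for illustration, not keras or flax apis.

```python
import jax
import jax.numpy as jnp


def dropout(x, key, rate=0.5):
    # Hypothetical helper: zero out entries with probability `rate`
    # and rescale the survivors.
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)


@jax.jit
def train_step(seed, x):
    # `seed` is a traced argument here, not a trace-time constant,
    # so the compiled program is the same for every step.
    seed, local_seed = jax.random.split(seed)
    return seed, dropout(x, local_seed)


seed0 = jax.random.PRNGKey(0)
x = jnp.ones(1024)

# The updated seed is carried forward, so each step gets a fresh mask.
seed1, out1 = train_step(seed0, x)
seed2, out2 = train_step(seed1, x)
```

because the key is an input rather than a constant, the same traced function serves every step, and the dropout mask changes whenever the carried seed changes.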

GallagherCommaJack avatar Aug 01 '23 04:08 GallagherCommaJack