keras
Current RNG setup is full of footguns in JAX
Right now, unseeded calls to e.g. keras.random.uniform acquire static seeds at trace time. This has a few undesirable consequences:
- every subsequent call reuses the same randomness (e.g. dropout applies the same fixed mask at every step instead of a fresh random one)
- the JAX compiler cache will almost never hit, since the seed constants baked into each trace are different every time
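A minimal plain-JAX reproduction of the first consequence, assuming the unseeded call behaves like drawing a Python-level seed during tracing. `unseeded_uniform` and `RATE` are hypothetical stand-ins, not Keras internals. Under `jax.jit` the seed is drawn once, when the function is traced, and is then frozen into the compiled program as a constant:

```python
import random

import jax
import jax.numpy as jnp

RATE = 0.5  # dropout rate, fixed for this example


def unseeded_uniform(shape):
    # Hypothetical stand-in for an unseeded random call: the seed comes from
    # Python's `random`, so under jit it is drawn once, at trace time, and
    # ends up as a constant inside the compiled program.
    seed = random.randint(0, 2**31 - 1)
    return jax.random.uniform(jax.random.PRNGKey(seed), shape)


@jax.jit
def dropout(x):
    keep = unseeded_uniform(x.shape) >= RATE
    return jnp.where(keep, x / (1.0 - RATE), 0.0)


x = jnp.ones((1024,))
first = dropout(x)   # traces, draws a seed, compiles
second = dropout(x)  # cache hit: reuses the compiled program and its seed
print(bool(jnp.array_equal(first, second)))  # True: same mask both times
```

Conversely, every time the function is traced again, a new seed constant is drawn, so the resulting program differs from the previous one; that is the mechanism behind the second consequence.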
To get around this, some kind of RNG state management is necessary. Flax does this with hierarchical management of RNGs from the Scope. Such an approach is fairly complex, however, and there might be simpler options, e.g. a single global RNG state that gets included with the training state in model.fit. Unseeded RNG calls would then do something along the lines of:
state.seed, local_seed = jax.random.split(state.seed)
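A minimal sketch of the single-global-state idea in plain JAX, assuming the seed is threaded through a jitted training step as ordinary state. `TrainState`, `train_step` and `dropout` are hypothetical names for illustration, not Keras APIs. Because the key is a jit input and output rather than a trace-time constant, each step gets a fresh mask while the step function is traced only once:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

TRACE_COUNT = 0  # incremented only when train_step is (re)traced


class TrainState(NamedTuple):
    # Hypothetical training state; real state would also hold params, etc.
    step: jax.Array
    seed: jax.Array  # RNG key carried alongside the rest of the state


def dropout(x, key, rate=0.5):
    # Seeded dropout: the key is an explicit input, not a trace-time constant.
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)


@jax.jit
def train_step(state, x):
    global TRACE_COUNT
    TRACE_COUNT += 1  # Python side effect: runs at trace time only
    # Split the carried seed, use one half for this call, and carry the
    # other half forward as the new state seed.
    new_seed, local_seed = jax.random.split(state.seed)
    y = dropout(x, local_seed)
    return TrainState(step=state.step + 1, seed=new_seed), y


state = TrainState(step=jnp.zeros((), dtype=jnp.int32), seed=jax.random.PRNGKey(0))
x = jnp.ones((1024,))
state, y1 = train_step(state, x)
state, y2 = train_step(state, x)
print(TRACE_COUNT)  # 1: the second step reused the compiled program
print(bool(jnp.array_equal(y1, y2)))  # False: each step used a fresh local seed
```

The jit cache is keyed on the key's shape and dtype rather than its value, so advancing the seed never forces a retrace; the cost is that every caller has to carry the seed in and out explicitly.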