Mava
[BUG] Inconsistent Jax PRNG Seed Management
Describe the bug
When we split a PRNG key to generate new keys (key, subkey = jax.random.split(key)), we don't have a consistent way of using the resulting keys.
We do the following:
- Sometimes we use key immediately and keep subkey as the new key, e.g. https://github.com/instadeepai/Mava/blob/ddcd87db9ece5aa975d51420440dd63b158e8d62/mava/components/jax/building/networks.py#L69 and https://github.com/instadeepai/Mava/blob/ddcd87db9ece5aa975d51420440dd63b158e8d62/mava/components/jax/executing/action_selection.py#L100.
- Sometimes we use subkey immediately and keep key, e.g. https://github.com/instadeepai/Mava/blob/ddcd87db9ece5aa975d51420440dd63b158e8d62/mava/components/jax/training/model_updating.py#L201.
- Sometimes we throw away one of the keys, e.g. https://github.com/instadeepai/Mava/blob/ddcd87db9ece5aa975d51420440dd63b158e8d62/mava/components/jax/training/step.py#L269.
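For illustration, the three styles listed above could look roughly like the following. This is a minimal sketch, not a copy of the linked Mava code; the variable names and the use of jax.random.normal are assumptions made for the example.

```python
import jax

key = jax.random.PRNGKey(0)

# Style 1: consume `key` right away and carry `subkey` forward as the new key.
key, subkey = jax.random.split(key)
a = jax.random.normal(key, (2,))
key = subkey

# Style 2: consume `subkey` right away and carry `key` forward.
key, subkey = jax.random.split(key)
b = jax.random.normal(subkey, (2,))

# Style 3: keep only one of the two outputs and discard the other.
key, _ = jax.random.split(key)
```

All three run without error, which is part of why the inconsistency is easy to miss when reading the code.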
To Reproduce
Steps to reproduce the behavior:
Expected behavior
Consistent handling of the keys returned when splitting a PRNG key: the same convention everywhere for which key is consumed and which is kept.
Context (Environment)
Additional context
Possible Solution
Adopt a single convention for splitting PRNG keys and apply it throughout the codebase.
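One possible convention, sketched here as an assumption rather than an agreed design, is to always carry the first output forward and consume the second. The helper name next_key is hypothetical and does not exist in Mava or JAX.

```python
import jax
import jax.numpy as jnp


def next_key(key):
    """Hypothetical helper: split once and return (carry_key, use_key).

    The caller always keeps the first key for future splits and consumes
    the second key exactly once.
    """
    carry_key, use_key = jax.random.split(key)
    return carry_key, use_key


key = jax.random.PRNGKey(42)

# Every call site follows the same shape: rebind `key`, consume the other.
key, init_key = next_key(key)
params = jax.random.normal(init_key, (3,))

key, action_key = next_key(key)
action = jax.random.categorical(action_key, jnp.zeros(4))
```

The opposite convention (consume the first, carry the second) would work equally well; the point is only that every call site reads the same way.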
Both keys returned are brand new keys, so it doesn't necessarily matter which one is consumed and which one is kept to create further new keys. I think consistency is nice, but I don't know if I would consider this a bug, unless I've misunderstood something fundamental about JAX PRNG keys.
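A quick way to check the claim that both outputs are new keys is to compare them with the parent and with each other. This is a small sketch using only jax.random.PRNGKey and jax.random.split; the seed value is arbitrary.

```python
import jax
import jax.numpy as jnp

parent = jax.random.PRNGKey(0)
new_key, subkey = jax.random.split(parent)

# Neither output equals the parent, and the two outputs differ from each other.
parent_vs_new = bool(jnp.array_equal(parent, new_key))
parent_vs_sub = bool(jnp.array_equal(parent, subkey))
new_vs_sub = bool(jnp.array_equal(new_key, subkey))

# Splitting is deterministic: the same parent always yields the same pair.
new_key_again, subkey_again = jax.random.split(parent)
same_on_repeat = bool(jnp.array_equal(new_key, new_key_again)) and bool(
    jnp.array_equal(subkey, subkey_again)
)
```

What does matter is that a key is never consumed twice, for example by drawing samples from key and then also splitting or sampling from that same key again.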