
[BUG] Inconsistent Jax PRNG Seed Management

Open KaleabTessera opened this issue 2 years ago • 1 comments

Describe the bug

When we split PRNG keys to generate new ones (key, subkey = jax.random.split(key)), we don't use the resulting keys in a consistent way.

We do the following:

  1. Sometimes we use key immediately and keep subkey as the new key, e.g. https://github.com/instadeepai/Mava/blob/ddcd87db9ece5aa975d51420440dd63b158e8d62/mava/components/jax/building/networks.py#L69 and https://github.com/instadeepai/Mava/blob/ddcd87db9ece5aa975d51420440dd63b158e8d62/mava/components/jax/executing/action_selection.py#L100
  2. Sometimes we use subkey immediately and keep key, e.g. https://github.com/instadeepai/Mava/blob/ddcd87db9ece5aa975d51420440dd63b158e8d62/mava/components/jax/training/model_updating.py#L201
  3. Sometimes we throw away one of the keys, e.g. https://github.com/instadeepai/Mava/blob/ddcd87db9ece5aa975d51420440dd63b158e8d62/mava/components/jax/training/step.py#L269
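
The three patterns above can be sketched roughly as follows. This is an illustrative reconstruction, not code copied from the linked Mava files; the seed, sample shapes and variable names are assumptions made for the example.

```python
import jax

key = jax.random.PRNGKey(0)  # assumed seed, for illustration only

# Pattern 1: consume `key` right away, carry `subkey` forward as the new key.
key, subkey = jax.random.split(key)
sample_1 = jax.random.normal(key, (2,))
key = subkey

# Pattern 2: consume `subkey` right away, carry `key` forward.
key, subkey = jax.random.split(key)
sample_2 = jax.random.normal(subkey, (2,))

# Pattern 3: split, keep one half and discard the other.
key, _ = jax.random.split(key)
sample_3 = jax.random.normal(key, (2,))
```

All three run without error, which is why the inconsistency is easy to miss when reading the code.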

To Reproduce

Steps to reproduce the behavior:

Expected behavior

A consistent convention for splitting and consuming PRNG keys.

Context (Environment)

Additional context

Possible Solution

Adopt a single convention for splitting PRNG keys and apply it throughout the codebase.
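
One possible shape for such a convention is to always rebind key to the first output of the split and consume the second. The helper name next_key below is hypothetical and does not exist in Mava or JAX; it only wraps jax.random.split to make the roles of the two outputs explicit.

```python
import jax

def next_key(key):
    """Hypothetical helper: split once, return (key to carry, key to consume)."""
    carry_key, use_key = jax.random.split(key)
    return carry_key, use_key

key = jax.random.PRNGKey(42)  # assumed seed, for illustration only

# Every call site follows the same shape: rebind `key`, consume the second output.
key, init_key = next_key(key)
params = jax.random.normal(init_key, (3,))

key, action_key = next_key(key)
noise = jax.random.uniform(action_key, (3,))
```

Whether the carried key is the first or the second output is arbitrary; the point of the sketch is only that every call site makes the same choice.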

KaleabTessera avatar Jun 21 '22 08:06 KaleabTessera

Both keys returned by the split are brand-new keys, so it doesn't necessarily matter which one is consumed and which one is kept to create further new keys. I think consistency is nice, but I don't know if I would consider this a bug, unless I've misunderstood something fundamental about JAX PRNG keys.
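
A small check along these lines, with an assumed seed and shapes, shows that both outputs of jax.random.split differ from the parent key and from each other, and that the case to avoid is passing the same key to two sampling calls:

```python
import jax
import jax.numpy as jnp

parent = jax.random.PRNGKey(0)  # assumed seed, for illustration only
first, second = jax.random.split(parent)

# Each half differs from the parent and from the other half.
halves_differ = not bool(jnp.array_equal(first, second))
first_is_new = not bool(jnp.array_equal(first, parent))
second_is_new = not bool(jnp.array_equal(second, parent))

# Either half can be consumed; the two draws are simply different samples.
draw_a = jax.random.normal(first, (4,))
draw_b = jax.random.normal(second, (4,))

# Reusing a key, by contrast, reproduces exactly the same draw.
repeat = jax.random.normal(first, (4,))
same_when_reused = bool(jnp.array_equal(draw_a, repeat))
```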

EdanToledo avatar Jun 21 '22 13:06 EdanToledo