
LazyRNG can accidentally have key sharing across layers.

Open · levskaya opened this issue 2 years ago · 2 comments

User-reported repro:

import jax
from flax import linen as nn

class Leaf(nn.Module):
  def __call__(self, x):
    return x + jax.random.randint(self.make_rng("rng"), (), 0, 100)

class Node(nn.Module):
  leaf_name: str
  @nn.compact
  def __call__(self, x):
    return Leaf(name=self.leaf_name)(x)

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    return (Node(name="ab", leaf_name="cdef")(x),
            Node(name="abc", leaf_name="def")(x),
    )

print(Model().apply({}, 0, rngs={"rng": jax.random.PRNGKey(33)}))
# (DeviceArray(23, dtype=int32), DeviceArray(23, dtype=int32))

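This happens because no separator is added between path components at: https://github.com/google/flax/blob/main/flax/core/scope.py#L102

The module names along each path are hashed back to back. As a result, the paths `ab/cdef` and `abc/def` both feed the same bytes, `abcdef`, into the hash, and the two `Leaf` modules derive the same RNG key.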
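A minimal sketch of the collision in plain Python, assuming the key derivation hashes each path component in order into a single SHA-1 digest. Successive `hashlib` `update` calls give the same digest as hashing the concatenated bytes, so only the byte stream matters. The helpers `stream_no_sep`, `stream_with_sep` and `path_hash` are hypothetical illustrations, not Flax's actual code.

```python
import hashlib

def stream_no_sep(path):
    """Hypothetical: bytes fed to the hash when names are concatenated directly."""
    return b"".join(name.encode("utf-8") for name in path)

def stream_with_sep(path, sep=b"/"):
    """Hypothetical: bytes fed to the hash when each name is followed by a separator."""
    return b"".join(name.encode("utf-8") + sep for name in path)

def path_hash(stream):
    # Hypothetical: SHA-1 the stream and keep the first 4 bytes as an unsigned
    # 32-bit integer, the kind of value jax.random.fold_in accepts.
    return int.from_bytes(hashlib.sha1(stream).digest()[:4], byteorder="big")

a = ("ab", "cdef")  # Node "ab" -> Leaf "cdef" in the repro
b = ("abc", "def")  # Node "abc" -> Leaf "def" in the repro

print(stream_no_sep(a), stream_no_sep(b))  # b'abcdef' b'abcdef'
print(path_hash(stream_no_sep(a)) == path_hash(stream_no_sep(b)))  # True
print(stream_with_sep(a), stream_with_sep(b))  # b'ab/cdef/' b'abc/def/'
```

A separator only removes the ambiguity if it can never occur inside a module name; prefixing each component with its length makes the encoding unambiguous without that assumption.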

levskaya, May 26 '22 23:05