flax
flax copied to clipboard
LazyRNG can accidentally have key sharing across layers.
User-reported repro:
import jax
from flax import linen as nn
class Leaf(nn.Module):
    """Innermost module of the repro: adds one random integer to its input."""

    def __call__(self, x):
        """Return ``x`` plus a scalar integer drawn from the "rng" stream.

        The key is requested via ``make_rng("rng")``; the draw has shape
        ``()`` and lies in ``[0, 100)``.
        """
        return x + jax.random.randint(self.make_rng("rng"), (), 0, 100)
class Node(nn.Module):
    """Intermediate module that wraps a single ``Leaf`` with a chosen name."""

    # Name given to the child Leaf; together with this module's own name it
    # forms the path under which the leaf requests its RNG key.
    leaf_name: str

    @nn.compact
    def __call__(self, x):
        """Apply a ``Leaf`` named ``self.leaf_name`` to ``x``."""
        return Leaf(name=self.leaf_name)(x)
class Model(nn.Module):
    """Top-level module holding two differently named Node/Leaf pairs."""

    @nn.compact
    def __call__(self, x):
        """Return the outputs of both ``Node`` branches as a 2-tuple.

        The names are chosen so that both paths concatenate to the same
        string when no separator is used: "ab" + "cdef" == "abc" + "def".
        """
        return (
            Node(name="ab", leaf_name="cdef")(x),
            Node(name="abc", leaf_name="def")(x),
        )
# Run the model once with a fixed "rng" key. Both leaves print the same value,
# showing that the two distinct layers received the same key.
print(Model().apply({}, 0, rngs={"rng": jax.random.PRNGKey(33)}))
# (DeviceArray(23, dtype=int32), DeviceArray(23, dtype=int32))
This occurs because we don't add a separator between the path components (so "ab" + "cdef" and "abc" + "def" both become "abcdef") at: https://github.com/google/flax/blob/main/flax/core/scope.py#L102