
Force no split in `make_rng`

Open cgarciae opened this issue 1 year ago • 5 comments

Discussed in https://github.com/google/flax/discussions/3113

Originally posted by zaccharieramzi, May 24, 2023:

I have the following situation: I am using a Dropout layer multiple times without `nn.scan` or `nn.while_loop`, so I cannot use `split_rngs={"dropout": False}`. However, I would still like to use the same dropout mask twice.

Is it possible to tell `make_rng` not to split the key for certain collections?

If I take the original dropout example, I would like to do something like this:

```python
# Setup.
import jax
import jax.numpy as jnp
import flax.linen as nn

# Randomness.
seed = 0
root_key = jax.random.PRNGKey(seed=seed)
main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)

# A simple network.
class MyModel(nn.Module):
  num_neurons: int
  training: bool
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.num_neurons)(x)
    # Set the dropout layer with a rate of 50%.
    # When the `deterministic` flag is `True`, dropout is turned off.
    x = nn.Dropout(rate=0.5, deterministic=not self.training)(x)
    x = nn.Dropout(rate=0.5, deterministic=not self.training)(x)
    return x

# Instantiate `MyModel` (you don't need to set `training=True` to
# avoid performing the forward pass computation).
my_model = MyModel(num_neurons=3, training=False)

x = jax.random.uniform(key=main_key, shape=(3, 4, 4))

# Initialize with `flax.linen.init()`.
# The `params_key` is equivalent to a dictionary of PRNGs.
# (Here, you are providing only one PRNG key.)
variables = my_model.init(params_key, x)

# Perform the forward pass with `flax.linen.apply()`.
# Flax Modules are frozen dataclasses, so `my_model.training = True` raises
# an error; clone the module with `training=True` instead.
y = my_model.clone(training=True).apply(variables, x, rngs={'dropout': dropout_key})
```

and still have `jnp.sum(y == 0.) / (3*4*3)` approximately equal to 0.5.

For more context, I am actually trying to implement Deep Equilibrium Models using jaxopt and Flax, where the function defining the fixed point uses dropout. I also looked into whether the `split_rngs` functionality could be extended to jaxopt, but I think it's going to be difficult.

cgarciae, May 24 '23 14:05
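To illustrate why the mask has to be shared in the Deep Equilibrium setting from the original post, here is a minimal plain-JAX sketch. It does not use jaxopt, and the map `f`, the weight scale and all sizes are made up for illustration. The dropout mask is sampled once and reused on every iteration, so the iterated function stays the same and naive fixed-point iteration can converge. Resampling the mask on each call would change the function at every step, leaving no single fixed point to converge to.

```python
import jax
import jax.numpy as jnp

n = 16
k_w, k_x, k_mask = jax.random.split(jax.random.PRNGKey(0), 3)
# Small weights keep the map a contraction (illustrative choice).
W = 0.1 * jax.random.normal(k_w, (n, n)) / jnp.sqrt(n)
x = jax.random.normal(k_x, (n,))

rate = 0.5
# Sample the dropout mask ONCE, with inverted-dropout scaling as in nn.Dropout.
keep = jax.random.bernoulli(k_mask, 1.0 - rate, (n,))
mask = keep / (1.0 - rate)


def f(z):
  # The same mask is used on every call, so f is a fixed function of z.
  return jnp.tanh(W @ (mask * z) + x)


def fixed_point(f, z0, num_iters=100):
  # Plain fixed-point iteration z <- f(z).
  return jax.lax.fori_loop(0, num_iters, lambda _, z: f(z), z0)


z_star = fixed_point(f, jnp.zeros(n))
residual = float(jnp.linalg.norm(f(z_star) - z_star))
print(residual)  # tiny, since f is a contraction here
```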

Hey @zaccharieramzi, I've converted the discussion into an issue, as it seems like something we should improve. I've created #3114, which would allow you to optionally pass the rng key to each Dropout layer, e.g.:

```python
# A simple network.
class MyModel(nn.Module):
  num_neurons: int
  training: bool
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.num_neurons)(x)
    # Set the dropout layer with a rate of 50%.
    # When the `deterministic` flag is `True`, dropout is turned off.
    # Draw a single key and pass it to both layers (only needed when training).
    key = self.make_rng('dropout') if self.training else None
    x = nn.Dropout(rate=0.5, deterministic=not self.training)(x, rng=key)
    x = nn.Dropout(rate=0.5, deterministic=not self.training)(x, rng=key)
    return x
```

This way both layers will produce the same mask.

cgarciae, May 24 '23 14:05

Would there be a way to propagate this information rather than having to pass the key to each Dropout layer? In my case, I would need to call `key = self.make_rng("dropout")` and pass the key down to the actual Dropout layers, which are nested deep inside different `nn.Module`s.

Something like:

```python
key = self.make_rng('dropout')
x = MyModule(...)(x, rng=key)
```

where `MyModule` does not originally have an `rng` parameter in its API.

zaccharieramzi, May 24 '23 15:05

Of course, I understand it might be much more complex to do, so it's really just a question.

zaccharieramzi, May 24 '23 15:05

@cgarciae I see that this was closed, so maybe you missed my earlier question. In modules like `dot_product_attention`, the dropout is typically hardcoded, with no way to set the rng. Do you think it's best, then, to reimplement all these modules so that the rng can be passed in?

zaccharieramzi, May 25 '23 08:05

FYI @zaccharieramzi, I added a `dropout_rng` argument to `nn.MultiHeadDotProductAttention` in #3384, so you can get the same dropout mask.

chiamp, Nov 01 '23 15:11