Force no split in `make_rng`
Discussed in https://github.com/google/flax/discussions/3113
Originally posted by zaccharieramzi May 24, 2023
I have the following situation: I am using a `Dropout` layer multiple times without an `nn.scan` or `nn.while_loop`, so I cannot use `split_rngs={"dropout": False}`.
However, I would still like to use the same dropout mask twice.
Is it possible to tell `make_rng` not to split the key for certain collections?
If I take the original dropout example, I would like to do something like:
```python
# Setup.
import jax
import jax.numpy as jnp
import flax.linen as nn

# Randomness.
seed = 0
root_key = jax.random.PRNGKey(seed=seed)
main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)

# A simple network.
class MyModel(nn.Module):
    num_neurons: int
    training: bool

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.num_neurons)(x)
        # Set the dropout layers with a rate of 50%.
        # When the `deterministic` flag is `True`, dropout is turned off.
        x = nn.Dropout(rate=0.5, deterministic=not self.training)(x)
        x = nn.Dropout(rate=0.5, deterministic=not self.training)(x)
        return x

# Instantiate `MyModel` (you don't need to set `training=True` to
# avoid performing the forward pass computation).
my_model = MyModel(num_neurons=3, training=False)
x = jax.random.uniform(key=main_key, shape=(3, 4, 4))

# Initialize with `Module.init()`.
# The `params_key` is equivalent to a dictionary of PRNGs.
# (Here, you are providing only one PRNG key.)
variables = my_model.init(params_key, x)

# Perform the forward pass with `Module.apply()`.
# Flax modules are frozen dataclasses, so `my_model.training = True` raises an
# error; build a training instance instead (it shares the same variables).
my_model = MyModel(num_neurons=3, training=True)
y = my_model.apply(variables, x, rngs={'dropout': dropout_key})
```
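and still have `jnp.sum(y == 0.) / (3*4*3)` ≈ 0.5, since the output has shape `(3, 4, 3)`. With two independent masks, as the code above produces, the fraction is instead about 0.75.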
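The target of 0.5 follows from simple probability: with two independent masks an element survives only if both keep it, so the expected zero fraction is 1 - 0.5² = 0.75, while reusing one mask leaves it at 0.5. The pure-JAX sketch below checks this empirically; the `dropout` helper is hypothetical and only mimics the masking rule of `nn.Dropout` (a Bernoulli keep mask, with kept values scaled by `1 / (1 - rate)`).

```python
import jax
import jax.numpy as jnp

RATE = 0.5
KEEP_PROB = 1.0 - RATE

def dropout(x, key):
    """Hypothetical inverted dropout, using the same masking rule as nn.Dropout."""
    mask = jax.random.bernoulli(key, p=KEEP_PROB, shape=x.shape)
    return jnp.where(mask, x / KEEP_PROB, 0.0)

def zero_fraction(y):
    return float(jnp.mean(y == 0.0))

# A large all-ones input so the empirical fractions sit close to their expectations.
x = jnp.ones((1000, 100))
k1, k2 = jax.random.split(jax.random.PRNGKey(0))

# Two independent masks: expected zero fraction 1 - KEEP_PROB**2 = 0.75.
independent = zero_fraction(dropout(dropout(x, k1), k2))

# The same key (hence the same mask) twice: the second pass zeroes exactly
# the same elements, so the expected zero fraction stays at RATE = 0.5.
shared = zero_fraction(dropout(dropout(x, k1), k1))

print(f"independent masks: {independent:.3f}, shared mask: {shared:.3f}")
```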
For more context, I am actually trying to implement Deep Equilibrium Models using `jaxopt` and `flax`, where the function defining the fixed point uses dropout.
I also looked into whether the `split_rngs` functionality could be extended to `jaxopt`, but I think it's going to be difficult.
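Why the mask has to stay fixed in this setting: a fixed-point solver evaluates the same layer over and over, and if dropout drew a fresh mask on every evaluation, the map would change between iterations and there would be no single fixed point to converge to. The pure-JAX sketch below draws the mask once and reuses it on every iteration. The layer `f`, the weight `W`, and the plain `jax.lax.fori_loop` iteration are hypothetical stand-ins; no `jaxopt` solver is used here.

```python
import jax
import jax.numpy as jnp

RATE = 0.5
KEEP_PROB = 1.0 - RATE

def f(z, x, W, mask):
    # Hypothetical DEQ layer: tanh(z W + x), followed by dropout with a FIXED mask.
    h = jnp.tanh(z @ W + x)
    return jnp.where(mask, h / KEEP_PROB, 0.0)

def fixed_point(x, W, mask, num_iters=100):
    # Naive fixed-point iteration z <- f(z); the same mask is used every time.
    def body(_, z):
        return f(z, x, W, mask)
    return jax.lax.fori_loop(0, num_iters, body, jnp.zeros_like(x))

k_x, k_w, k_mask = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (4, 8))
# Small weights keep the map a contraction even with the 1/KEEP_PROB rescaling.
W = 0.1 * jax.random.normal(k_w, (8, 8)) / jnp.sqrt(8.0)
# Draw the dropout mask once, outside the solver loop.
mask = jax.random.bernoulli(k_mask, p=KEEP_PROB, shape=x.shape)

z_star = fixed_point(x, W, mask)
residual = float(jnp.max(jnp.abs(f(z_star, x, W, mask) - z_star)))
print("fixed-point residual:", residual)
```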
Hey @zaccharieramzi, I've converted the discussion into an issue, as it seems like something we should improve.
I've created #3114, which would allow you to optionally specify the rng key for each `Dropout` layer, e.g.:
```python
import flax.linen as nn

# A simple network.
class MyModel(nn.Module):
    num_neurons: int
    training: bool

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.num_neurons)(x)
        # Set the dropout layers with a rate of 50%.
        # When the `deterministic` flag is `True`, dropout is turned off.
        # Draw one key and pass it to both layers so they use the same mask.
        key = self.make_rng('dropout')
        x = nn.Dropout(rate=0.5, deterministic=not self.training)(x, rng=key)
        x = nn.Dropout(rate=0.5, deterministic=not self.training)(x, rng=key)
        return x
```
This way both layers will produce the same mask.
Would there be a way to propagate this information rather than having to pass it around to each dropout?
Indeed, in my case I would need to do `key = self.make_rng("dropout")` and pass it down to the actual dropout layers, which are nested deep inside different `nn.Module`s.
Something like:
```python
key = self.make_rng('dropout')
x = MyModule(...)(x, rng=key)
```
where originally `MyModule` does not have the `rng` parameter in its API.
Of course, I understand it might be much more complex to do, so it's really just a question.
@cgarciae I see that this was closed, so maybe you missed my earlier question. In modules like `dot_product_attention`, the dropout is typically hardcoded, without the possibility to set the rng.
Do you think it's best, then, to reimplement all these modules with the possibility to pass the rng?
FYI @zaccharieramzi, I added a `dropout_rng` argument to `nn.MultiHeadDotProductAttention` in #3384, so you can get the same dropout mask.