RNN FLIP
- Start Date: 2022-08-18
- FLIP PR: N/A
- FLIP Issue: #2396
- Authors: Jasmijn Bastings (@bastings) and Cristian Garcia (@cgarciae)
Summary
This FLIP proposes a comprehensive set of changes and additions to Flax in order to improve its support for recurrent architectures. This includes updates to the RNNCellBase API and a new RNNBase API.
Motivation
Implementing well-known recurrent architectures is currently challenging in Flax: several required features are missing, are not easy to implement by hand, and are prone to user error, so they should ideally be encapsulated by new abstractions within Flax. This would provide our users with clean, correct, and efficient abstractions for using RNN cells.
Requirements:
- Masking: We need to support a batch of sequences that contain padding at the end of each sequence. We do not intend to support non-contiguous padding, i.e. padding that is not at the end of a sequence, for performance reasons.
- Recurrent Dropout: Support for recurrent dropout in cells (e.g. dropout on the state of the cell).
- Bidirectionality: The ability to process a sequence in both the forward and reverse directions, respecting padding.
- Performance: The proposed classes should be benchmarked to provide the best performance in terms of step time and/or memory use.
Implementation
High-level structure
To satisfy the previous requirements, this proposal includes some updates to the existing RNNCellBase API and expands it to include higher-level APIs that can apply cells to full sequences and handle padding.
Concretely, we propose to have these 3 levels of abstraction:
- Cells: all RNNCellBase subclasses, such as LSTMCell and GRUCell, which implement the stepwise logic. These already exist in Flax today.
- Layers: a class (RNN) or possibly several classes (RNN, GRU, LSTM) that take a cell and scan it over a sequence while respecting possible padding values.
- Bidirectional: a single class that takes a forward and a backward RNN instance, processes the input sequence in both directions, and merges the results.
Example
We start with a code example of what you could do with the proposed API.
cell = nn.LSTMCell(features=32) # Now accepts features.
# Encodes a batch of input sequences.
final_carry, outputs = nn.RNN(cell)(inputs, segmentation_mask)
A Bidirectional layer with an LSTM RNN and a GRU RNN for the forward and backward directions, respectively, would look like this:
forward_rnn = nn.RNN(nn.LSTMCell(features=32))
backward_rnn = nn.RNN(nn.GRUCell(features=32))
# Bidirectional combinator.
bi_rnn = nn.Bidirectional(forward_rnn, backward_rnn)
# Encodes a batch of input sequences in both directions.
carry, outputs = bi_rnn(inputs, segmentation_mask)
Next we will discuss RNN (RNNBase), Bidirectional, and proposed changes to RNNCellBase.
RNNBase
The proposed RNNBase Module is responsible for applying a single RNNCellBase instance over a batch of input sequences; as its name suggests, it would serve as the base class for all RNN layers. While a family of RNN layers like LSTM or GRU could be implemented on top of RNNBase, in this proposal we will only discuss the implementation of a generic RNN class that handles arbitrary cells.
The main feature that RNNBase brings is that it contains an rnn_forward method that has all the necessary logic to apply a cell over a sequence using nn.scan while handling padding.
Inheritance Strategy
We will first discuss how to actually structure the RNNBase class. Since Linen Modules are dataclasses, figuring out how to enable inheritance and design the API overall becomes a challenge. We will discuss three options:
- An Object Oriented approach
- A mostly Functional approach
- A pure function approach
We investigate options 1 and 2 in this rnn_base_prototype colab.
Option 1: Object Oriented approach
The object oriented approach tries to have some parameters of the computation available as fields of the object (e.g. cell) and others as inputs to rnn_forward (e.g. inputs):
class RNNBase(nn.Module):
  cell: RNNCellBase
  time_axis: int = -2

  def rnn_forward(
      self, inputs: Array,
      initial_carry: Optional[Carry], segmentation_mask: Optional[Array],
      reverse: bool,
      **scan_kwargs) -> Tuple[Carry, Output]:
    ...
The downside of this approach is that, due to how dataclasses take class fields and interpret them as constructor arguments, it becomes very hard or even impossible to express certain constructor patterns. For this strategy to work, one of the following is required:
- Completely redefine how the dataclass processes class fields to define the constructor, grouping required fields at the beginning and optional arguments at the end, and allowing shadowing to enable overloading.
- Allow the definition of a custom __init__ and add a new Module.initialize_dataclass method that the user manually calls at the end of __init__ to execute the additional logic dataclasses require, such as calling __post_init__.
Neither of these is ideal, though the second is easier to implement.
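For reference, the underlying constraint can be reproduced with plain dataclasses (a generic Python illustration, not Flax-specific; the class names are hypothetical):
import dataclasses

@dataclasses.dataclass
class Base:
  time_axis: int = -2  # a field with a default, as in RNNBase

@dataclasses.dataclass
class Child(Base):
  # Raises at class-definition time:
  # TypeError: non-default argument 'cell' follows default argument
  cell: object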
Option 2: Mostly Functional approach
This mostly functional approach still uses class-based inheritance, but the base class has no fields (and therefore no dataclass issues). The trade-off is that all the parameters of rnn_forward now have to be passed explicitly as arguments, so using it is more verbose.
class RNNBase(nn.Module):

  def rnn_forward(
      self, cell: RNNCellBase,
      inputs: Array, time_axis: int,
      initial_carry: Optional[Carry], segmentation_mask: Optional[Array],
      reverse: bool,
      **scan_kwargs) -> Tuple[Carry, Output]:
    ...
In this example cell and time_axis are now passed as additional arguments to rnn_forward whereas previously they were passed through self.
Option 3: Pure function approach
As a final option, we could avoid classes altogether and just provide an rnn_forward function. This has the following downsides:
- Having access to self is useful in case we want to create variables or use any of the methods from Module associated with our class.
- rnn_forward would become harder to use as you have to import it.
- If the API grew beyond rnn_forward, you would need to keep track of multiple functions external to your class.
All things considered, we are going with option 2: while not perfect, it has the fewest drawbacks and doesn’t require any additional machinery to work.
rnn_forward
Now we will discuss the main API, which consists of the rnn_forward method (a better name is welcome). It is tasked with using nn.scan to apply a cell over a sequence while taking care of the padding logic and other conveniences such as processing the sequence in reverse order. The proposed signature is the following:
def rnn_forward(
    self, cell: RNNCellBase,
    inputs: Array, time_axis: int,
    initial_carry: Optional[Carry], segmentation_mask: Optional[Array],
    reverse: bool,
    **scan_kwargs) -> Tuple[Carry, Output]:
  ...
The output of this method would be the final carry and the output sequence. Based on this, subclasses of RNNBase would define __call__ and internally use rnn_forward according to their need.
RNN
The first and probably most important implementation based on RNNBase will be the RNN class. RNN is a very generic Module that can handle arbitrary cells and would mostly behave as a wrapper over the rnn_forward method. RNN's API could be as simple as this:
class RNN(RNNBase):
  cell: RNNCellBase
  time_axis: int = -2
  unroll: int = 1
  # scan kwargs
  variable_axes: Mapping[CollectionFilter, InOutScanAxis] = FrozenDict()
  variable_broadcast: CollectionFilter = 'params'
  variable_carry: CollectionFilter = False
  split_rngs: Mapping[PRNGSequenceFilter, bool] = FrozenDict({'params': False})

  def __call__(
      self, inputs, segmentation_mask: Optional[Array] = None,
      initial_carry: Optional[Carry] = None,
  ) -> Tuple[Carry, Output]:
    ...
The default values chosen for scan’s keyword arguments such as variable_broadcast are defined such that they work when using RNN with most common cells like LSTMCell.
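For illustration, __call__ could simply forward the configured fields to rnn_forward (a sketch; the proposal does not spell out this body, and the exact keyword names forwarded to nn.scan are assumptions):
def __call__(
    self, inputs, segmentation_mask: Optional[Array] = None,
    initial_carry: Optional[Carry] = None,
) -> Tuple[Carry, Output]:
  # Delegate to the base class, passing the scan-related fields as **scan_kwargs.
  return self.rnn_forward(
      cell=self.cell,
      inputs=inputs,
      time_axis=self.time_axis,
      initial_carry=initial_carry,
      segmentation_mask=segmentation_mask,
      reverse=False,
      variable_axes=self.variable_axes,
      variable_broadcast=self.variable_broadcast,
      variable_carry=self.variable_carry,
      split_rngs=self.split_rngs,
      unroll=self.unroll)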
Masking
One remaining detail regarding the RNNBase API is the masking strategy; here are some strategies used in other frameworks:
- Binary masking: specifies, per sample and timestep, whether that data point should be included in the computation. In general, binary masks can be non-contiguous (e.g., [1, 1, 0, 1]), which has implications for the code dealing with them. Keras uses this style.
- Sequence length masking: specifies, per sample, how many timesteps should be included in the computation; this assumes all included timesteps are at the beginning of the sequence and all padding is at the end.
- Sequence packing: allows “packing” multiple samples into one sequence; an integer segmentation mask is given to differentiate between examples and also specifies padding (e.g. [1, 1, 1, 2, 2, 0, 0]). This strategy potentially reduces the amount of padding needed, so it is more efficient in terms of computation and memory; however, it introduces some additional challenges that we discuss next. PyTorch uses this representation (see pack_padded_sequence).
Sequence packing (see the LM1B example) offers a more general representation than binary masking and is potentially more efficient than sequence length masking. However, because each batch now holds a variable number of samples, JAX’s static shape requirements make tasks such as extracting the last output/carry for each sample more challenging. Despite these challenges, we propose to implement sequence packing given its generality and computational benefits.
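To make the three representations concrete, here is the same toy batch (two sequences of lengths 3 and 2, max length 5) expressed in each style (illustrative values only):
import jax.numpy as jnp

# Binary masking: one 0/1 entry per sample and timestep (may be non-contiguous).
binary_mask = jnp.array([[1, 1, 1, 0, 0],
                         [1, 1, 0, 0, 0]])

# Sequence length masking: one integer per sample; padding assumed at the end.
seq_lengths = jnp.array([3, 2])

# Sequence packing: both samples packed into a single row; positive integers
# identify samples and 0 marks padding. Here no padding is needed at all,
# illustrating the reduced padding.
segmentation_mask = jnp.array([[1, 1, 1, 2, 2]])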
Bidirectional
Bidirectional processing can be achieved via a Module that accepts a forward_rnn Module and a reverse_rnn Module, both of which should be RNNBase instances, to process the input sequence in both directions. Here is some pseudocode for the implementation:
def __call__(self, inputs, segmentation_mask):
  # Encode in the forward direction.
  forward_carry, forward_outputs = self.forward_rnn(inputs, segmentation_mask)

  # Flip the input sequences for the backward pass, respecting padding.
  backward_inputs = flip_sequences(inputs, segmentation_mask)

  # Encode in the reverse direction.
  backward_carry, backward_outputs = self.reverse_rnn(
      backward_inputs, segmentation_mask)

  # Flip the backward outputs to match the original order,
  # taking padding into account.
  backward_outputs = flip_sequences(backward_outputs, segmentation_mask)

  # Merge both sequences.
  outputs = self.merge_fn(forward_outputs, backward_outputs)

  return (forward_carry, backward_carry), outputs
Here flip_sequences is a function that flips a sequence while leaving the padded values at the end, and merge_fn is a function that takes both outputs and fuses them (concatenation by default). As showcased at the beginning of this document, usage would look like this:
forward_rnn = nn.RNN(nn.LSTMCell(features=32))
backward_rnn = nn.RNN(nn.GRUCell(features=32))
# Bidirectional combinator.
bi_rnn = nn.Bidirectional(forward_rnn, backward_rnn)
# Encodes a batch of input sequences in both directions.
carry, outputs = bi_rnn(inputs, segmentation_mask)
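For reference, here is one way flip_sequences could be implemented for the sequence-length representation (a sketch only; the FLIP does not fix the implementation, and with a segmentation mask the per-row lengths would first have to be derived from it):
import jax
import jax.numpy as jnp

def flip_sequences(inputs: jax.Array, lengths: jax.Array) -> jax.Array:
  """Reverses each row of `inputs` along time, keeping padding at the end.

  Assumes `inputs` has shape [batch, time, features], padding is contiguous
  at the end of each row, and `lengths[b]` is the number of valid timesteps
  in row b.
  """
  max_length = inputs.shape[1]
  # Output step t of row b reads input step (lengths[b] - 1 - t) for the valid
  # steps; the padded tail positions map onto other padded positions.
  idxs = (jnp.arange(max_length - 1, -1, -1)[None, :] + lengths[:, None]) % max_length
  return jnp.take_along_axis(inputs, idxs[:, :, None], axis=1)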
RNNCellBase redesign
The RNNCellBase API can be improved in two ways:
- It has no support for recurrent dropout, nor mechanisms to help implement it efficiently.
- The current initialize_carry class method leaks some implementation details onto the user, who has to manually specify the batch dimension and the size of the carry.
Make initialize_carry an instance method
Currently initialize_carry is a class method; because of this, certain implementation details leak outside the cell, and information such as the batch dimensions and carry size has to be provided by the user instead of being inferred by the cell. The solution is to make initialize_carry an instance method so the object has enough information to infer most of the relevant shapes. This change also implies adding new fields such as features to cells like LSTMCell / GRUCell. Note that this update is not strictly required and is a breaking change, but it would be a quality-of-life improvement to the API.
We propose the following signature:
def initialize_carry(self, sample_inputs: Array) -> Carry:
  ...
Here sample_inputs is an array with the same shape as the cell’s inputs at a single timestep (no time dimension). Since most of the required metadata is available as attributes or can be gathered from methods such as make_rng, the user experience is greatly simplified.
For example, typical usage of an LSTMCell currently looks like this:
features = 32
inputs = jnp.ones((16, 10, 8))
cell = nn.LSTMCell()
carry = nn.LSTMCell.initialize_carry(
    jax.random.PRNGKey(0), inputs.shape[:1], features)
Whereas with the proposed change it would be simplified to:
inputs = jnp.ones((16, 10, 8))
cell = nn.LSTMCell(features=32)
carry = cell.initialize_carry(inputs[:, 0])
The main difference is that the user no longer has to provide an RNG or compute shape information by hand; the improvement is even more evident for cells like nn.ConvLSTM, where defining the size argument is more involved for the user.
New create_mask method
The second modification to the RNNCellBase API is to add a new create_mask method to support recurrent dropout. In RNNs, dropout can be applied both to the inputs, which we will refer to as input dropout, and to the recurrent states (e.g., (c, h), depending on the cell type), which we will refer to as recurrent dropout (see e.g. [1603.05118] Recurrent Dropout without Memory Loss). There are a variety of ways to implement these dropout operations, which we analyze next:
Input dropout
- Sample from nn.Dropout at each time step (slow).
- Sample masks for all time steps in advance.
- Don’t handle input dropout at all: users should just apply nn.Dropout to the input in advance, outside scan. If we choose this, the cell only implements recurrent dropout, since that needs to be done inside the cell, but not input/output dropout, which can be done outside the cell.
Recurrent dropout
- Sample from nn.Dropout at each time step while freezing (not splitting) the RNG (slow and currently not supported; would require #2194).
- Sample masks for all time steps in advance. In most cases the same mask would be replicated for all timesteps.
A simple benchmark shows that precomputing masks and passing them to each step (option 2) is faster than precomputing RNGs and calling functions like jax.random.bernoulli inside scan (option 1).
[Benchmark plot: step time of precomputed masks vs. sampling inside scan]
Because of this, we propose the addition of the RNNCellBase.create_mask method such that RNNBase can support recurrent dropout for arbitrary cells. The proposed signature for create_mask would be the following:
def create_mask(self, sample_inputs: Array, steps: int) -> Mask:
  ...
Here the returned mask should be a pytree whose leaves have a leading dimension of size steps. This mask should be optionally accepted by the cell's __call__ method, which would now have the following signature:
def __call__(self, carry, inputs, mask=None) -> Tuple[Carry, Output]:
  ...
Note that for certain cells only the mask for the recurrent part (applied to c) is strictly required, since input dropout could be applied outside the cell; however, we propose to keep this API as general as possible.
Recurrent Dropout
Currently nn.Dropout performs two main operations: creating a random mask and applying that mask to the inputs using the dropout scaling formula.
# code from Dropout.__call__ in stochastic.py
rng = self.make_rng('dropout')                         # create_mask
broadcast_shape = list(inputs.shape)                   # create_mask
for dim in self.broadcast_dims:                        # create_mask
  broadcast_shape[dim] = 1                             # create_mask
mask = random.bernoulli(                               # create_mask
    rng, p=keep_prob, shape=broadcast_shape)           # create_mask
mask = jnp.broadcast_to(mask, inputs.shape)            # create_mask
return lax.select(                                     # apply_mask
    mask, inputs / keep_prob, jnp.zeros_like(inputs))  # apply_mask
When implementing recurrent dropout you still need this logic, but, as previously discussed, for performance reasons the masks have to be created for all timesteps ahead of time and the application of each mask has to happen later, inside scan. To make it easier for users to perform these steps while following best practices, a solution is to refactor the previous code into two methods, create_mask and apply_mask, with the following signatures:
class Dropout(nn.Module):

  def create_mask(self, shape: Tuple[int, ...]) -> Array:
    ...

  def apply_mask(self, inputs: Array, mask: Array) -> Array:
    ...
Without breaking any code, Dropout.__call__’s implementation could be simplified using these methods, and they would then also be available to facilitate the implementation of recurrent dropout and similar use cases.
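For illustration, with these helpers Dropout.__call__ could reduce to something like the following sketch (deterministic-mode and edge-case handling, e.g. rate == 1.0, is elided and assumed to stay as it is today):
def __call__(self, inputs, deterministic: bool = False):
  if deterministic or self.rate == 0.0:
    return inputs
  # Create the mask and apply it in two explicit, reusable steps.
  mask = self.create_mask(inputs.shape)
  return self.apply_mask(inputs, mask)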
LSTM Example
Putting it all together, cells like LSTMCell can be redefined to accept a features argument, as is common in most implementations:
class LSTMCell(RNNCellBase):
  features: int
  gate_fn: Callable[..., Any] = sigmoid
  activation_fn: Callable[..., Any] = tanh
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
  recurrent_kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = orthogonal()
  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
  carry_init_fn: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
  dtype: Optional[Dtype] = None
  param_dtype: Dtype = jnp.float32
  dropout_rate: float = 0.0
  recurrent_dropout_rate: float = 0.0

  def setup(self):
    self.dropout = Dropout(rate=self.dropout_rate)
    self.recurrent_dropout = Dropout(rate=self.recurrent_dropout_rate)

  def initialize_carry(self, sample_inputs: Array):
    rng = random.PRNGKey(0) if self.carry_init_fn is zeros \
        else self.make_rng('carry')
    key1, key2 = random.split(rng)
    mem_shape = sample_inputs.shape[:-1] + (self.features,)
    return (self.carry_init_fn(key1, mem_shape, self.dtype),
            self.carry_init_fn(key2, mem_shape, self.dtype))

  def create_mask(self, sample_inputs: Array, steps: int):
    if self.dropout_rate > 0:
      inputs_mask = self.dropout.create_mask((steps, *sample_inputs.shape))
    else:
      inputs_mask = None
    if self.recurrent_dropout_rate > 0:
      # The recurrent mask is applied to the hidden state, whose feature
      # dimension is `features` rather than the input feature dimension.
      carry_shape = sample_inputs.shape[:-1] + (self.features,)
      recurrent_mask = self.recurrent_dropout.create_mask(carry_shape)
      # The same recurrent mask is replicated for all timesteps.
      recurrent_mask = jnp.stack([recurrent_mask] * steps, axis=0)
    else:
      recurrent_mask = None
    return inputs_mask, recurrent_mask

  @compact
  def __call__(self, carry, inputs, mask=None):
    c, h = carry
    if mask is not None:
      inputs_mask, recurrent_mask = mask
      if inputs_mask is not None:
        inputs = self.dropout.apply_mask(inputs, inputs_mask)
      if recurrent_mask is not None:
        h = self.recurrent_dropout.apply_mask(h, recurrent_mask)
    hidden_features = h.shape[-1]
    # input and recurrent layers are summed so only one needs a bias.
    dense_h = partial(
        Dense, features=hidden_features, use_bias=True,
        kernel_init=self.recurrent_kernel_init,
        bias_init=self.bias_init, dtype=self.dtype, param_dtype=self.param_dtype)
    dense_i = partial(
        Dense, features=hidden_features, use_bias=False,
        kernel_init=self.kernel_init,
        dtype=self.dtype, param_dtype=self.param_dtype)
    i = self.gate_fn(dense_i(name='ii')(inputs) + dense_h(name='hi')(h))
    f = self.gate_fn(dense_i(name='if')(inputs) + dense_h(name='hf')(h))
    g = self.activation_fn(dense_i(name='ig')(inputs) + dense_h(name='hg')(h))
    o = self.gate_fn(dense_i(name='io')(inputs) + dense_h(name='ho')(h))
    new_c = f * c + i * g
    new_h = o * self.activation_fn(new_c)
    return (new_c, new_h), new_h
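To close the loop, end-to-end usage of the redefined cell together with the proposed RNN layer could look roughly like this (a sketch that assumes the classes proposed above; the RNG stream names are illustrative):
import jax
import jax.numpy as jnp

inputs = jnp.ones((16, 10, 8))  # [batch, time, features]
rnn = RNN(LSTMCell(features=32, dropout_rate=0.1, recurrent_dropout_rate=0.1))

params_key, dropout_key = jax.random.split(jax.random.PRNGKey(0))
variables = rnn.init({'params': params_key, 'dropout': dropout_key}, inputs)
final_carry, outputs = rnn.apply(
    variables, inputs, rngs={'dropout': dropout_key})
# outputs: [batch, time, 32]; final_carry: the (c, h) tuple of the last step.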
Thank you for this very comprehensive FLIP @cgarciae and @bastings; it's really impressive how much work you put into this, and the proposal is extremely detailed.
Thoughts / Questions
- explain current API: Overall, I think it could be useful to briefly explain the current state of our API. We currently have a single RNNCellBase and 4 subclasses: LSTMCell, OptimizedLSTMCell, GRUCell, and ConvLSTM. Each of these subclasses implements the abstract method initialize_carry... (then explain what this does...)
- RNN abstraction: I agree with the proposal that having an RNN abstraction is useful, since it will simplify the masking and scanning logic. In your proposal you suggest adding RNNBase and subclass RNN (and optionally GRU, LSTM as well). I have some questions about this:
  - What functionality would be in LSTM that is different from an RNN with cell LSTMCell?
  - If there is no added functionality, and we decide to only add RNN: why do we need RNNBase in the first place? Can't we just add a single class RNN, which is just a recurrent abstraction for one of the cells that can be passed in?
- make initialize_carry an instance method: I think this is a great idea! Indeed it is a breaking change, but as a first step we could see how many internal users we break and whether it is feasible to update them. Then we could implement a deprecation plan for the old API (similar to what we did for flax.nn and flax.optim). I really like how it simplifies user code.
- Bidirectional abstraction: You propose to add a single class nn.Bidirectional. What if we just add an argument bidirectional to the RNN abstraction? In that case we might have the limitation that we can only choose the same cell for both the forward and reverse run, but the upshot is that we have one abstraction less to maintain.
- Masking: It is not really clear from this proposal how exactly you are going to implement sequence packing. It mentions some difficulties, but not how you will address them. It was clear to me how you intended to implement sequence length masking (the user provides the sequence lengths), but how does this look for packing? Which inputs should the user provide? How complex will this be?
- Recurrent dropout: I am missing a motivation for this. Forgive my ignorance, but how important is this feature? Is this something that is absolutely fundamental to any RNN, or rather a nice-to-have? If it isn't fundamental we could also think of implementing this later if it turns out there is a need for it. The change also seems to involve modifying our Dropout layer, which may be fine but it does introduce more complexity.
Suggestions for implementation plan
One thing that is missing from this (quite big) FLIP is an implementation plan. Which parts are P0s, which are nice-to-haves, how long will each part take?
Thinking about this proposal and my comments above, what do you think of the following minimal (and perhaps sufficient) implementation plan? (Just a suggestion of course 😄 ):
- Add a single RNN abstraction that encapsulates scanning over an RNNCellBase instance, and add tests.
- Add masking logic to the RNN abstraction. In the proposal you mention sequence packing, but I think there are still things unclear here.
- Add a flag bidirectional to the RNN abstraction, which runs over the inputs forward and reverse with the same RNNCellBase (e.g., LSTMCell).
- Improve the API for initialize_carry and deprecate the old API.
- Add recurrent dropout.
Let me know what you think of all this!
Thank you @bastings and @cgarciae

Thanks for writing this up!
Some thoughts:
- Could you add the benchmark so we can decide whether the pre-computed mask is worth the added complexity (compared to using split_rngs={"dropout": False})?
- Please also consider making the carry a proper Linen variable. This would bring cells more in line with the rest of Flax, because right now the carry acts as a manually tracked variable.
- I wonder if we can make cells and/or nn.scan flexible enough such that you can just use nn.scan over them directly. Right now RNNBase feels like a special scan primitive for cells.
- I don't see the added benefit of separating RNNBase from RNN so that we can have specialized classes for LSTM and GRU. I think other frameworks have had such abstractions in the past mainly because they had specialized kernels which execute specific RNNs faster, but we are not constrained by custom kernels.
The FLIP has been moved to #2585; feedback was integrated there, but I will respond to some of the questions here for completeness:
explain current API: Overall, I think it could be useful to briefly explain the current state of our API. We currently have a single RNNCellBase and 4 subclasses: LSTMCell, OptimizedLSTMCell, GRUCell, and ConvLSTM. Each of these subclasses implements the abstract method initialize_carry... (then explain what this does...)
This explanation was added in the "Motivation" section.
RNN abstraction: I agree with the proposal that having an RNN abstraction is useful, since it will simplify the masking and scanning logic. In your proposal you suggest adding RNNBase and subclass RNN (and optionally GRU, LSTM as well). I have some questions about this:
- What functionality would be in LSTM that is different from an RNN with cell LSTMCell?
- If there is no added functionality, and we decide to only add RNN: why do we need RNNBase in the first place? Can't we just add a single class RNN, which is just a recurrent abstraction for one of the cells that can be passed in?
We agreed that it was indeed better to have a single RNN abstraction without a Base class. However, it is possible that specialized Modules like e.g. LSTM are added if an optimized kernel is provided (see jax.experimental.rnn).
make initialize_carry an instance method: I think this is a great idea! Indeed it is a breaking change, but as a first step we could see how many internal users we break and whether it is feasible to update them. Then we could implement a deprecation plan for the old API (similar to what we did for flax.nn and flax.optim). I really like how it simplifies user code.
We chose not to implement any breaking changes; however, this improvement could be done in the future along with other breaking changes. @jheek has some ideas on how to improve nn.scan, some of which could eliminate the need for initialize_carry entirely.
Bidirectional abstraction: You propose to add a single class nn.Bidirectional. What if we just add an argument bidirectional to the RNN abstraction? In that case we might have the limitation that we can only choose the same cell for both the forward and reverse run, but the upshot is that we have one abstraction less to maintain.
We've thought about this, but the implementation would hamper readability which goes against the Flax philosophy. We prefer to keep this in a Bidirectional class.
Masking: It is not really clear from this proposal how exactly you are going to implement sequence packing. It mentioned some difficulties, but not how you will address them. It was clear to me how you intended to implement sequence length masking (the user provides the sequence lengths), but how does this look for packing? Which inputs should the user provide? How complex will this be?
We decided not to implement sequence packing for now; if required, this could be implemented in a separate Module, e.g. PackedRNN, which would have a slightly different carry shape. However, we are keeping the segmentation_mask format as it is compatible with both packed and 'single sequence per row' implementations.
Recurrent dropout: I am missing a motivation for this. Forgive my ignorance, but how important is this feature? Is this something that is absolutely fundamental to any RNN, or rather a nice-to-have? If it isn't fundamental we could also think of implementing this later if it turns out there is a need for it. The change also seems to involve modifying our Dropout layer, which may be fine but it does introduce more complexity.
Ultimately we decided not to go with the Mask APIs (they were removed from the FLIP) in favor of a simpler strategy: leveraging Flax's PRNG streams to differentiate between recurrent dropout and regular dropout. Dropout was recently updated to let you specify the RNG stream name. If performance is a concern, we would recommend users leverage some of the options provided by the jax_default_prng_impl flag.
Thanks for the replies @cgarciae! I guess the last point that remains unanswered is the benchmarks. @jheek asked for them above, and @bastings told me that you did indeed run some, but I can't find any evidence of this. Could you please add the results to the FLIP or reply here?