recurrent example for docs
Motivation and description
Working with recurrent networks raises a lot of questions because they work rather differently from the stateless case.
I think it would be extremely helpful to have explicit examples: one for sequence-to-sequence and one for sequence-to-one.
Possible Implementation
I might come back and contribute this, but as I'm posting this I still don't think I'm doing this the intended way...
I haven’t used the library for recurrent nets, so would be interested to see how this works and am open to changes of API if necessary 👍
I've messed around with it more since writing this. Recurrent nets seem to require a fair amount of dedicated code, and I'm not sure FluxTraining.jl would be the place for all of it. In particular, I've found myself needing to write functions to:
- Predict sequence-to-sequence.
- Predict sequence-to-one.
- Recurrently predict future of a sequence after a seed.
- Each of these for batches.
Additionally, I wonder whether the way sequences are stored is uniform across the ecosystem. The Flux documentation itself strongly suggests that sequences should be nested arrays rather than rank-3 arrays.
I have used Flux + FluxTraining quite a bit for recurrent models in the past. In general, you shouldn't need to do anything special. Most of the work is related primarily to Flux and how it expects the data. Here is the situation I almost always end up in, and it might be useful for you.
- There is a function generate(T) that creates a D x T matrix, where T is time and D is the feature dimension.
- I can generate a vector of samples as samples = [generate(T) for _ in 1:nsamples].
- I use batchseq to turn this into a sequence of batches (from a batch of sequences). This is the key step.
- You can generate many batches by repeating the above steps with nsamples = batch_size.
In general as long as you think of a single sample in your dataset as a single sequence, then you can adapt the steps above to get them into the sequence of batches (samples) that Flux wants.
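To make the "batch of sequences" to "sequence of batches" step concrete, here is a rough sketch that does the conversion by hand with Base Julia only. It is not the batchseq implementation, just an illustration of the layout it is meant to produce; generate here is a made-up placeholder that returns random data, and to_batch_sequence is a hypothetical helper name.

```julia
# Hypothetical stand-in for a data generator: returns a D x T matrix
# (D features, T time steps). The random values are placeholders.
generate(T; D = 2) = rand(Float32, D, T)

# Turn a batch of sequences (vector of D x T matrices) into a sequence of
# batches (vector of T matrices, each of size D x nsamples).
# Assumes all samples share the same T (no padding is done here).
function to_batch_sequence(samples)
    T = size(first(samples), 2)
    return [reduce(hcat, [s[:, t] for s in samples]) for t in 1:T]
end

T, nsamples = 5, 3
samples = [generate(T) for _ in 1:nsamples]
xs = to_batch_sequence(samples)

length(xs)   # 5: one entry per time step
size(xs[1])  # (2, 3): features x samples
```

Column j of xs[t] is time step t of sample j, which is the layout a per-timestep loop like [m(xi) for xi in xs] expects.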
From there, achieving the different tasks is all in the loss function.
using Statistics: mean

# seq to seq prediction: average the loss over every time step
function seq2seq_loss(loss_fn)
    function _loss(m, xs, ys)
        yhats = [m(xi) for xi in xs]
        return mean(loss_fn(yhat, yi) for (yhat, yi) in zip(yhats, ys))
    end
    return _loss
end
# seq to one prediction: only the final time step contributes to the loss
function seq2one_loss(loss_fn)
    function _loss(m, xs, ys)
        yhats = [m(xi) for xi in xs]
        return loss_fn(yhats[end], ys[end])
    end
    return _loss
end
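using Flux
using Distributions: Categorical

# samplers for mapping the previous token to the next token
# used below in sample_model
sample_softmax(y::AbstractVector) =
    Flux.onehot(rand(Categorical(softmax(y))), 1:length(y))
function sample_softmax(y::AbstractMatrix)
    # draw one index per column (one per sample in the batch)
    ŷs = [rand(Categorical(p)) for p in eachcol(softmax(y))]
    return Flux.onehotbatch(ŷs, 1:size(y, 1))
end
# greedy alternative: pick the most likely token instead of sampling
sample_best(y::AbstractVector) = Flux.onehot(argmax(y), 1:length(y))
sample_best(ys::AbstractMatrix) = Flux.onehotbatch(Flux.onecold(ys), 1:size(ys, 1))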
# recurrently predict a sequence given a primer input sequence
# (primer needs at least one input, otherwise last(tokens) has nothing to read)
function sample_model(model, nseq, primer = [], sampler = identity)
    tokens = [model(x) for x in primer]
    ncurrent = length(tokens)
    while ncurrent < nseq
        nexttoken = model(sampler(last(tokens)))
        push!(tokens, nexttoken)
        ncurrent += 1
    end
    return tokens
end
Note that batching does not affect any of the functions above. As long as you get the "sequence of batches" format right, you should be good.
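As a small illustration of why batching makes no difference to the per-timestep loop, here is a sketch with a toy stateful model written in Base Julia. ToyRecur is a made-up stand-in (a running sum), not Flux's Recur, and unroll is a hypothetical helper name; the point is only that the same loop handles a sequence of vectors and a sequence of matrices.

```julia
# Toy stand-in for a recurrent model (NOT Flux.Recur): the hidden state is a
# running sum of the inputs, and each call returns the updated state.
mutable struct ToyRecur
    state::Any
end
ToyRecur() = ToyRecur(nothing)
function (m::ToyRecur)(x)
    m.state = m.state === nothing ? copy(x) : m.state .+ x
    return m.state
end

# The same per-timestep loop used in the loss functions above.
unroll(m, xs) = [m(x) for x in xs]

# One sample: a sequence of 3 feature vectors (D = 2).
single = [Float32[1, 2], Float32[3, 4], Float32[5, 6]]
# A batch of 2 identical samples: a sequence of 3 matrices of size D x 2.
batched = [hcat(x, x) for x in single]

ys_single  = unroll(ToyRecur(), single)
ys_batched = unroll(ToyRecur(), batched)

# Every column of the batched output matches the single-sample output.
all(yb[:, 1] == ys for (yb, ys) in zip(ys_batched, ys_single))  # true
```

Each call to unroll gets a fresh ToyRecur() so that state from one sequence does not leak into the next; with a real stateful Flux model the equivalent is resetting the state between sequences.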
If you still want to express all this using FluxTraining, then the following is something I've used in the past.
# split a batch into inputs and targets; a plain sequence is shifted by one step
get_inout_seq(xs::AbstractVector) = xs[1:(end - 1)], xs[2:end]
get_inout_seq(xs::NTuple{2}) = xs[1], xs[2]

struct BPTTTrainingPhase <: AbstractTrainingPhase end

function FluxTraining.step!(learner, phase::BPTTTrainingPhase, batch)
    xs, ys = get_inout_seq(batch)
    FluxTraining.runstep(learner, phase, (xs = xs, ys = ys)) do handle, state
        state.grads = gradient(learner.params) do
            # run the model over every time step, then score the whole sequence
            state.ŷs = [learner.model(xi) for xi in state.xs]
            state.loss = learner.lossfn(state.ŷs, state.ys)
            return state.loss
        end
        Flux.update!(learner.optimizer, learner.params, state.grads)
    end
end
struct BPTTValidationPhase <: AbstractValidationPhase
    nfeedback  # number of trailing steps where the model drives itself
    sampler    # maps a model output to the next model input
end
BPTTValidationPhase() = BPTTValidationPhase(0, identity)
BPTTValidationPhase(nfeedback) = BPTTValidationPhase(nfeedback, identity)

function FluxTraining.step!(learner, phase::BPTTValidationPhase, batch)
    xs, ys = get_inout_seq(batch)
    FluxTraining.runstep(learner, phase, (xs = xs, ys = ys)) do _, state
        n = length(state.xs) - phase.nfeedback
        # n steps where input drives model
        state.ŷs = [learner.model(state.xs[i]) for i in 1:n]
        # nfeedback steps where the model drives itself
        for _ in (n + 1):length(state.xs)
            ŷ = phase.sampler(state.ŷs[end])
            push!(state.ŷs, learner.model(ŷ))
        end
        state.loss = learner.lossfn(state.ŷs, state.ys)
    end
end
I don't need to do this for training recurrent models, but I found it nice for a particular project where BPTT was the thing I was comparing against. Specifically, BPTTValidationPhase is nice for evaluating models in the recurrently driven mode, where they feed themselves their own output as input.
If your data is already in a big rank-3 array, then you can make your axis order feature x samples x time, and use Base.Iterators or MLUtils.jl to partition it along the second axis into a vector of feature x batch x time chunks. A Recur model in Flux should consume these chunks correctly.
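For the rank-3 case, the partitioning along the sample axis could look roughly like the following. This is a sketch using Base.Iterators only; the array sizes and batch_size are made-up values, and the random data is a placeholder.

```julia
# Placeholder data: 4 features, 10 samples, 6 time steps,
# laid out as feature x samples x time.
data = rand(Float32, 4, 10, 6)
batch_size = 4

# Split the sample axis (axis 2) into index ranges of at most batch_size,
# then slice out one feature x batch x time chunk per range.
chunks = [data[:, idx, :] for idx in Iterators.partition(1:size(data, 2), batch_size)]

size.(chunks)  # [(4, 4, 6), (4, 4, 6), (4, 2, 6)]

# If a sequence of batches is preferred, a chunk can be cut along time
# into feature x batch matrices, one per time step.
xs = [chunks[1][:, :, t] for t in axes(chunks[1], 3)]
```

Note that the last chunk is smaller when the number of samples is not a multiple of batch_size; whether to keep or drop it is a choice left to the data loading code.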
Otherwise, I find the approach of treating each sample as a self-contained time series the most intuitive and the most compatible with existing data wrangling/loading packages like MLUtils.jl. Just remember to batchseq before passing to the Flux model.
Note that we actually do support 3D arrays of shape (features, batch, timesteps) as inputs to RNN layers. The reason it's not documented/advertised is we're not sure whether the API makes sense. For example, how do you differentiate between a batched sequence input to a normal RNN and one timestep of input to a conv-based RNN? The current implementation also does the same partitioning by timesteps you'd do by hand internally, so it should be slower than Kyle's suggestion above.
Note I edited my comments from the original to correct a mistake in the order of the axis dimensions. Clearly, the time I've been spending with Jax recently is leaking...
I'm trying to understand how Zygote does gradient accumulation in the case of an RNN. In the following, I compare its result with a manual gradient accumulation, and the two differ. What could be the reason here? The code is self-contained and runnable.
using Flux
using Random
# x in format (feature, samples, timesteps)
x = reshape([0.84147096, 0.9092974, 0.14112], 1, 1, 3)
y = -0.7568025
layer1 = Flux.Recur(Flux.RNNCell(1 => 5, tanh))
layer2 = Flux.Dense( 5 => 1 )
model = Flux.Chain(layer1, layer2)
e, g = Flux.withgradient(model, x, y) do m, xi, yi
    yhat = [m(xi[:,:,i]) for i in 1:3] # timesteps = 3
    return Flux.mse(yhat[3], yi)
end
println("flux gradient dWx: ", g[1][1][1].cell.Wi)
#-------- get individual gradients at each step -----------------
c1 = deepcopy(layer1.cell)
c2 = deepcopy(c1)
c3 = deepcopy(c2)
h0 = zeros(5, 1) # initial state zero
# unroll the three steps by hand, giving each step its own copy of the weights
e3, f = Flux.withgradient(c1.Wi, c2.Wi, c3.Wi,
                          c1.Wh, c2.Wh, c3.Wh,
                          c1.b, c2.b, c3.b) do Wi1, Wi2, Wi3, Wh1, Wh2, Wh3, b1, b2, b3
    h1 = tanh.( Wi1 * x[:,:,1] + Wh1 * h0 + b1); y1 = layer2(h1)
    h2 = tanh.( Wi2 * x[:,:,2] + Wh2 * h1 + b2); y2 = layer2(h2)
    h3 = tanh.( Wi3 * x[:,:,3] + Wh3 * h2 + b3); y3 = layer2(h3)
    Flux.mse(y3, y)
end
println("accumulated dWx: ", f[1]+f[2]+f[3])