Backprop through time
Continuing our series "cool things we can't have yet", and inspired by this comment, I've been thinking about how we'll expose BPTT. Currently, given a forward pass like this:
```julia
loss = 0.0
for word in seq
    loss += model(word)
end
loss
```
If we want to backprop neither over the whole sequence at once (gradient outside the loop) nor over only a single step at a time (gradient inside the loop), then we need to split the loop into chunks as follows:
```julia
for chunk in Iterators.partition(seq, n)  # n: chunk length (hypothetical)
    θ̄ = gradient() do
        loss = 0.0
        for word in chunk
            loss += model(word)
        end
        loss
    end
end
```
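A minimal sketch of the chunked (truncated) scheme, written in plain Julia so it runs without Flux or Zygote. It assumes a hypothetical scalar RNN h_t = tanh(w*h_{t-1} + x_t) with loss sum_t h_t, and writes each chunk's gradient by hand in place of gradient(). The names forward_chunk, backward_chunk and truncated_grads are illustrative only. The hidden state is carried from one chunk into the next, but gradient is not.

```julia
# Hypothetical toy model: h_t = tanh(w * h_{t-1} + x_t), loss = sum_t h_t.

# Forward over one chunk starting from h0; returns h_1..h_T for the chunk.
function forward_chunk(w, h0, xs)
    hs = Float64[]
    h = h0
    for x in xs
        h = tanh(w * h + x)
        push!(hs, h)
    end
    return hs
end

# Hand-written backward over one chunk: d(sum(hs))/dw.
# h0 is treated as a constant, so gradient stops at the chunk boundary.
function backward_chunk(w, h0, xs, hs)
    dw = 0.0
    dh = 0.0                        # gradient flowing back from later steps
    for t in length(xs):-1:1
        dh += 1.0                   # direct contribution of h_t to the loss
        dpre = (1 - hs[t]^2) * dh   # through tanh
        hprev = t == 1 ? h0 : hs[t-1]
        dw += dpre * hprev
        dh = dpre * w               # pass gradient on to h_{t-1}
    end
    return dw
end

# Truncated BPTT: one gradient per chunk of k steps.
function truncated_grads(w, xs, k)
    grads = Float64[]
    h = 0.0
    for chunk in Iterators.partition(xs, k)
        hs = forward_chunk(w, h, chunk)
        push!(grads, backward_chunk(w, h, chunk, hs))
        h = hs[end]                 # carry the state, not the gradient
    end
    return grads
end

total_loss(w, xs) = sum(forward_chunk(w, 0.0, xs))

xs = [0.5, -0.3, 0.8, 0.1, -0.6, 0.2]
w = 0.7

full = truncated_grads(w, xs, length(xs))[1]   # a single chunk = full BPTT
fd = (total_loss(w + 1e-6, xs) - total_loss(w - 1e-6, xs)) / 2e-6
println(full, " ≈ ", fd)
println(truncated_grads(w, xs, 2))             # three chunks of two steps
```

With one chunk this reduces to backprop over the whole sequence (the hand-written gradient should agree with the finite-difference estimate), and with a chunk length of 1 it reduces to the single-step case.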
An alternative to this is to just expose primitives that let us fiddle with time steps directly. Consider:
```julia
record = Capacitor(5)  # proposed, not yet existing API: remembers the last 5 steps
for x in seq
    θ̄ = gradient() do
        record() do
            model(x)
        end
    end
end
```
Alright, bear with me here. This is written as if we were backprop-ing over only a single time step at a time, but with the model evaluation wrapped in record. The idea is that record will log the backpropagators from the last 5 calls of the closure it is passed, and then chain them together on the backward pass. This means we can actually backpropagate through the n previous iterations of the loop (here n = 5), i.e. backpropagation through time.
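To make the chaining idea concrete, here is a minimal plain-Julia sketch of what a recorder like this might do internally. Everything in it is hypothetical: Capacitor, record!, backprop and rnn_step are illustrative names (not Flux or Zygote API), the model is the same toy scalar RNN h_t = tanh(w*h_{t-1} + x_t), and each step's backpropagator is a hand-written closure mapping dL/dh_t to (dL/dw from that step, dL/dh_{t-1}).

```julia
# Hypothetical recorder: keeps the pullbacks of the last `n` steps.
struct Capacitor
    n::Int
    pullbacks::Vector{Function}
end
Capacitor(n::Int) = Capacitor(n, Function[])

# Store a step's pullback, forgetting anything older than `n` steps.
function record!(c::Capacitor, pb)
    push!(c.pullbacks, pb)
    length(c.pullbacks) > c.n && popfirst!(c.pullbacks)
    return c
end

# Chain the stored pullbacks from newest to oldest, accumulating dL/dw.
function backprop(c::Capacitor, dh)
    dw = 0.0
    for pb in reverse(c.pullbacks)
        dw_step, dh = pb(dh)
        dw += dw_step
    end
    return dw
end

# One RNN step, returning h and a pullback dh -> (dw_step, dhprev).
function rnn_step(w, hprev, x)
    h = tanh(w * hprev + x)
    pullback(dh) = ((1 - h^2) * dh * hprev, (1 - h^2) * dh * w)
    return h, pullback
end

# Each iteration backprops loss_t = h_t through the last n steps
# (sliding-window BPTT).
function sliding_grads(w, xs, n)
    c = Capacitor(n)
    h = 0.0
    grads = Float64[]
    for x in xs
        h, pb = rnn_step(w, h, x)
        record!(c, pb)
        push!(grads, backprop(c, 1.0))
    end
    return grads
end

# Full-sequence reference: h_T as a function of w, for a finite-difference check.
function last_h(w, xs)
    h = 0.0
    for x in xs
        h = tanh(w * h + x)
    end
    return h
end

xs = [0.5, -0.3, 0.8, 0.1, -0.6, 0.2]
w = 0.7
g_window = sliding_grads(w, xs, length(xs))[end]   # window covers every step
g_fd = (last_h(w + 1e-6, xs) - last_h(w - 1e-6, xs)) / 2e-6
println(g_window, " ≈ ", g_fd)
```

When the window is at least as long as the sequence, the chained pullbacks recover the full gradient of h_T; a shorter window simply drops the older terms.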
What's cool about this is that it makes BPTT completely orthogonal to the structure of the forward pass. The recorder can equally well be set up to backprop through the last n steps on every iteration (sliding-window BPTT) or only on every nth iteration (normal BPTT), or anything in between, and it can be set up differently for different parts of the model. It also isn't specific to any particular RNN implementation; for example, this will work even though we have to backprop through h across loop iterations:
```julia
record = Capacitor(5)
h = ...
for word in seq
    θ̄ = gradient() do
        record() do
            y, h = model(h, word)  # h is threaded through every iteration
            loss(word, y)
        end
    end
end
```
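The two schedules differ only in which time steps receive gradient at each iteration. A tiny sketch, for a window of n steps; sliding_window and every_nth are hypothetical helper names, not part of Flux or of the proposed Capacitor API.

```julia
# Steps backpropagated at iteration t with a window of n (illustrative helpers).

# Sliding window: every iteration, backprop through the last n steps.
sliding_window(t, n) = max(1, t - n + 1):t

# Normal truncated BPTT: only every nth iteration, over the last n steps.
every_nth(t, n) = t % n == 0 ? (t - n + 1:t) : (1:0)

for t in 1:6
    println(t, "  sliding: ", collect(sliding_window(t, 3)),
               "  every nth: ", collect(every_nth(t, 3)))
end
```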
The main question is whether this is actually going to be intuitive for people (who aren't travelling at 88mph). If it looks weird right now, I think that's partly because we're not used to using gradient this way; getting used to that will make the extra feature easier to reason about. At least for sliding windows, I think it's strictly better than flow-based alternatives.
What about the train_step! proposed at https://github.com/FluxML/Flux.jl/pull/607#issuecomment-462070498?
I'm trying to translate some PyTorch code that does BPTT, and the code I'm translating (ENAS-Pytorch) seems to put an explicit loop inside its forward function to step through time (35 time steps in this case).
I suppose that's possible with Flux's train! function: you could put the train! call inside a loop that counts up to the number of time steps you want for BPTT, giving it a batch of data in each iteration. Would that work?
Yeah, right now the Flux approach to this is essentially the same as PyTorch's. step! and train! should be orthogonal issues; to mentally rewrite the examples, you should be able to replace gradient with step!, or the outer loop and gradient with train!.
Hi, where's the code implementing BPTT (in recurrent.jl?)? I wrote some code doing BPTT in plain Julia and would like to compare the intermediate results with Flux. The fprop is easy to understand, but I couldn't find where the bprop is implemented.
Could someone point me to where to look? Thanks. (A Google search landed me on this page, which seems the best place to ask.)
There is no code in Flux.jl that implements BPTT directly. It is "just" a matter of calling gradient over the whole loop, as Mike did in his first code snippet.
Does the gradient function accumulate gradients through time somehow? In the seq2one case, how does the function know how to loop? Ideally, I would like to be able to intercept and verify the accumulation process.
For comparison, a manually written BPTT has something like the following. Are there variables similar to dWhh kept somewhere? Karpathy's code:
```python
# Backward pass of Karpathy's min-char-rnn (hs[-1] holds the initial hidden state)
for t in reversed(range(len(inputs))):
    dy = np.copy(ps[t])
    dy[targets[t]] -= 1  # backprop into y (softmax + cross-entropy)
    dWhy += np.dot(dy, hs[t].T)
    dby += dy
    dh = np.dot(Why.T, dy) + dhnext  # backprop into h
    dhraw = (1 - hs[t] * hs[t]) * dh  # backprop through tanh nonlinearity
    dbh += dhraw
    dWxh += np.dot(dhraw, xs[t].T)
    dWhh += np.dot(dhraw, hs[t-1].T)
    dhnext = np.dot(Whh.T, dhraw)
```
Yes, the AD backend, Zygote, will handle gradient accumulation through the loop. See this comment for how you can implement many-to-many or many-to-one models. Also check the recurrent docs.
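One way to intercept and verify the accumulation is to keep a hand-written reference next to the AD version. Below is a minimal plain-Julia port of the Karpathy-style loop (no Flux), assuming a vanilla tanh RNN with a softmax output and one-hot inputs; rnn_forward, rnn_backward, the sizes and the random data are all hypothetical. The explicit accumulators dWxh, dWhh, ... play the role asked about above; an AD gradient of the same loss should equal their final totals.

```julia
using Random

softmax(z) = (e = exp.(z .- maximum(z)); e ./ sum(e))

# Forward pass; returns the loss plus cached hidden states and probabilities.
function rnn_forward(Wxh, Whh, Why, bh, by, xs, targets, h0)
    hs = Dict(0 => h0)                       # hs[t]: hidden state after step t
    ps = Dict{Int,Vector{Float64}}()
    loss = 0.0
    for t in eachindex(xs)
        hs[t] = tanh.(Wxh * xs[t] .+ Whh * hs[t-1] .+ bh)
        ps[t] = softmax(Why * hs[t] .+ by)
        loss -= log(ps[t][targets[t]])
    end
    return loss, hs, ps
end

# Hand-written BPTT: gradients explicitly accumulated over time steps.
function rnn_backward(Wxh, Whh, Why, xs, targets, hs, ps)
    dWxh, dWhh, dWhy = zero(Wxh), zero(Whh), zero(Why)
    dbh, dby = zeros(size(Whh, 1)), zeros(size(Why, 1))
    dhnext = zeros(size(Whh, 1))
    for t in reverse(eachindex(xs))
        dy = copy(ps[t])
        dy[targets[t]] -= 1                  # softmax + cross-entropy gradient
        dWhy .+= dy * hs[t]'
        dby .+= dy
        dh = Why' * dy .+ dhnext             # backprop into h
        dhraw = (1 .- hs[t] .^ 2) .* dh      # backprop through tanh
        dbh .+= dhraw
        dWxh .+= dhraw * xs[t]'
        dWhh .+= dhraw * hs[t-1]'
        dhnext = Whh' * dhraw
    end
    return dWxh, dWhh, dWhy, dbh, dby
end

Random.seed!(0)
H, V, T = 4, 3, 5                            # hidden size, vocab size, length
Wxh, Whh, Why = 0.1 * randn(H, V), 0.1 * randn(H, H), 0.1 * randn(V, H)
bh, by = zeros(H), zeros(V)
onehot(i) = (v = zeros(V); v[i] = 1.0; v)
xs = [onehot(rand(1:V)) for _ in 1:T]
targets = rand(1:V, T)
h0 = zeros(H)

loss, hs, ps = rnn_forward(Wxh, Whh, Why, bh, by, xs, targets, h0)
dWxh, dWhh, dWhy, dbh, dby = rnn_backward(Wxh, Whh, Why, xs, targets, hs, ps)

# Finite-difference check of one entry of dWhh.
ε = 1e-6
Wp, Wm = copy(Whh), copy(Whh)
Wp[1, 2] += ε; Wm[1, 2] -= ε
fd = (rnn_forward(Wxh, Wp, Why, bh, by, xs, targets, h0)[1] -
      rnn_forward(Wxh, Wm, Why, bh, by, xs, targets, h0)[1]) / (2ε)
println(dWhh[1, 2], " ≈ ", fd)
```

The finite-difference line is only a quick sanity check on one entry; the same comparison can be made against the gradient that AD returns for the equivalent model.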
Thanks for your reply, darsnack. I posted another question in issue 144, as that one seems more recent. Would you mind taking a look?