Multiple `setup_rnn` calls
Hey!
I'm attempting to integrate the Sashimi backbone into some audio models -- I'd like to train in convolutional mode and run validation inference in RNN mode, but my reading of the code seems to imply that the `setup_step` call isn't repeatable or reversible (#67 seems to imply this as well).
In the case that I temporarily want to infer in RNN mode, but then switch back to the convolutional training mode, what's my best option?
I think it should actually work fine to call it multiple times. You'll just need to call it before each time you want to use the RNN mode.
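For concreteness, here's a minimal sketch of that workflow, assuming a backbone that exposes `forward()`, `setup_step()`, `default_state()`, and `step()` in the style of the repo's generation code (exact names/signatures may differ in your version, and `model`, `criterion`, `x_train`, `x_val`, `targets` are placeholders):

```python
import torch

# Sketch only -- `model`, `criterion`, `x_train`, `x_val`, `targets` are placeholders,
# and setup_step()/default_state()/step() signatures are assumed from the RNN-mode API.

# 1. Train with the parallel (convolutional) forward pass as usual.
model.train()
y = model(x_train)                     # (B, L, D) -> (B, L, D)
loss = criterion(y, targets)
loss.backward()

# 2. Before RNN-mode validation, (re)build the recurrent step matrices.
model.eval()
with torch.no_grad():
    model.setup_step()                 # call again before every RNN-mode pass
    state = model.default_state(x_val.shape[0], device=x_val.device)
    outs = []
    for t in range(x_val.shape[1]):
        y_t, state = model.step(x_val[:, t], state)   # one timestep at a time
        outs.append(y_t)

# 3. Switch back to training: forward() doesn't use the discretized step matrices,
#    so convolutional training is unaffected; just call setup_step() again next time.
model.train()
y = model(x_train)
```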
Ah, I must've misread the code / the paper -- when we call `setup_rnn`, the kernel state matrices are initialized for `step` calls only, but not for regular `forward` calls? That is, there's no leakage into the convolutional mode?

Edit: I think I get it -- the discrete matrices set up during the `setup_step` calls aren't used during the convolutional pass.
While I have you, I'm attempting to understand the output of the Sashimi forward pass and how I'd do a simple cross entropy loss on the predictions. If I'm reading the Sashimi backbone right, in the UpPool blocks we shift the inputs one over to the right (by padding left and removing the last element). Does that imply, if `y = self.forward(x)`, that the `i`th element of `y` is the prediction based on the `< i` elements of `x`?
> Ah, I must've misread the code / the paper -- when we call `setup_rnn`, the kernel state matrices are initialized for `step` calls only, but not for regular `forward` calls? That is, there's no leakage into the convolutional mode?
>
> Edit: I think I get it -- the discrete matrices set up during the `setup_step` calls aren't used during the convolutional pass.
Yep, that's right!
> While I have you, I'm attempting to understand the output of the Sashimi forward pass and how I'd do a simple cross entropy loss on the predictions. If I'm reading the Sashimi backbone right, in the UpPool blocks we shift the inputs one over to the right (by padding left and removing the last element). Does that imply, if `y = self.forward(x)`, that the `i`th element of `y` is the prediction based on the `< i` elements of `x`?
Like any autoregressive model, training is done using "teacher forcing", i.e. computing a cross entropy loss on the predictions $p(x_i \mid x_0, \ldots, x_{i-1})$ all at once for all $i$. It suffices to have a strictly causal sequence-to-sequence map over all positions at once, $(x_i)_{i < L} \mapsto (y_i)_{i < L}$, and to use the cross entropy loss on each $y_i$.
Note that nothing about this is specific to Sashimi. This setup is the same as the ubiquitous "Language Modeling" (LM) task, and any sequence model can be used as long as causality is enforced (e.g. Transformers with the triangular causal mask). The shifting in the Sashimi UpPool blocks is to enforce causality as you said.
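In code, the teacher-forcing loss is then just a cross entropy between the model's outputs and the (unshifted) inputs, since the causal shift already happens inside the model. A hedged sketch, where `embedding`, `model`, and `decoder` are illustrative placeholders for the input encoder, Sashimi backbone, and output head:

```python
import torch.nn.functional as F

# x: (B, L) integer tokens (e.g. quantized audio).
# Because of the causal shift in the UpPool blocks, logits[:, i] is the prediction
# for x_i given x_0, ..., x_{i-1}, so no extra shifting of targets is needed here.
features = model(embedding(x))                 # (B, L, D)
logits = decoder(features)                     # (B, L, vocab_size)

loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),       # (B*L, vocab_size)
    x.reshape(-1),                             # (B*L,)
)
```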
Sweet, that lines up with my intuition. One last question for you: when trying to generate conditionally, my thought was to convert to RNN mode, initialize the default state, then feed the conditioning sequence through the model to build up state non-auto-regressively (e.g. predict `x_i` given `x_{i-1}`, discarding the output `y_i`s), then auto-regressively generate the output sequence (e.g. predict `y_i` given `y_{i-1}`).

Does this line up with how one conditions S4-based models?
That's right. This is how the generation script does it.
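For reference, a minimal sketch of that prime-then-generate loop (method names follow the RNN-mode API assumed above; `embedding`, `decoder`, `cond_tokens`, and `num_new_tokens` are illustrative placeholders):

```python
import torch

model.eval()
with torch.no_grad():
    model.setup_step()
    state = model.default_state(cond_tokens.shape[0], device=cond_tokens.device)

    # 1. Prime: feed the conditioning tokens to build up state. Intermediate outputs
    #    are discarded; the last one predicts the first new token.
    for t in range(cond_tokens.shape[1]):
        y_t, state = model.step(embedding(cond_tokens[:, t]), state)

    # 2. Generate autoregressively, feeding each sampled token back in.
    generated = []
    for _ in range(num_new_tokens):
        probs = decoder(y_t).softmax(dim=-1)                 # (B, vocab_size)
        token = torch.multinomial(probs, 1).squeeze(-1)      # sample next token
        generated.append(token)
        y_t, state = model.step(embedding(token), state)

    samples = torch.stack(generated, dim=1)                  # (B, num_new_tokens)
```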
Thank you, appreciate the quick response times.