s4
Return states for the entire sequence
I'm working on an application in which, during the training phase, I pass a tensor to S4 and would like to start generating a rollout (in recurrent mode) from an arbitrary point of the tensor (say, the j'th element of the i'th sequence). With an RNN, I can save the hidden state fed to that step and use it to carry information from the past into the rollout. In S4, however, I can only initialize a state randomly to feed the step function, which means I lose some of that information. Is there a way to get the states for every position of a sequence from the forward function, as we can with the step function?
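
To make the request concrete, here is a minimal sketch of the desired behavior on a toy diagonal linear state space model. It is not the S4 implementation or its API: the names `step` and `forward_with_states`, the parameters `a`, `b`, `c`, and the update rule x_k = a * x_{k-1} + b * u_k, y_k = c . x_k are all assumptions chosen for illustration. It shows a forward pass that also returns the state after every position, so a rollout can resume from position j with the saved state instead of a random one.

```python
from typing import List, Optional, Tuple

# Toy diagonal linear SSM (illustrative only, NOT the S4 library API):
#   x_k = a * x_{k-1} + b * u_k   (elementwise, since a is diagonal)
#   y_k = sum(c * x_k)


def step(a: List[float], b: List[float], c: List[float],
         u: float, x: List[float]) -> Tuple[float, List[float]]:
    """Run one recurrent step and return (output, new state)."""
    x_next = [ai * xi + bi * u for ai, bi, xi in zip(a, b, x)]
    y = sum(ci * xi for ci, xi in zip(c, x_next))
    return y, x_next


def forward_with_states(a: List[float], b: List[float], c: List[float],
                        inputs: List[float],
                        x0: Optional[List[float]] = None
                        ) -> Tuple[List[float], List[List[float]]]:
    """Process a whole sequence recurrently, keeping the state at every position."""
    x = list(x0) if x0 is not None else [0.0] * len(a)
    outputs, states = [], []
    for u in inputs:
        y, x = step(a, b, c, u, x)
        outputs.append(y)
        states.append(x)  # states[k] is the state after consuming inputs[k]
    return outputs, states


# Hypothetical parameters and a short input sequence.
a = [0.9, 0.5, -0.3]
b = [1.0, 0.5, 0.25]
c = [0.2, -0.4, 1.0]
inputs = [1.0, 0.0, -1.0, 2.0, 0.5, 0.0]

outputs, states = forward_with_states(a, b, c, inputs)

# Resume from position j using the saved state rather than a random one.
j = 2
resumed_outputs, _ = forward_with_states(a, b, c, inputs[j + 1:], x0=states[j])
```

In this sketch, resuming from `states[j]` reproduces `outputs[j + 1:]` exactly, which is the property a randomly initialized state would lose. The states here come from a sequential pass; whether the same states can be obtained cheaply from S4's forward function is the open question.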