Consider allowing hidden state initialisation via ssm_state input parameter for selective_scan_fn
Please, please, consider adding an `ssm_state` input parameter to `selective_scan_fn` to allow hidden state initialisation for the Mamba block. Please also consider making the hidden state differentiable, since `selective_scan_fn` currently notes:

> Note that the gradient of the last state is not considered in the backward pass.
This change should open the path to an encoder-decoder Mamba architecture and to an encoder-only BERT-like architecture. The setup analogous to RNNs would be: a Mamba encoder runs through the input sequence while ignoring its outputs, and its last hidden state is used to initialise the decoder, which starts from a `<START>` input token and unrolls the state recursively. For the encoder to train end-to-end, the last hidden state has to be differentiable; for the decoder to work, the Mamba block needs to be able to accept a hidden state at initialisation. This would also open a route to encoder-only BERT-style models, classification/embedding problems, etc. A sketch of the requested interface is below.
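For illustration, here is a minimal pure-PyTorch sketch of what the requested interface could look like, written in the style of the reference scan rather than the CUDA kernel. Assumptions: the simplified case with `B`/`C` of shape `(batch, dstate, seqlen)` and no `z`/`delta_bias`/`delta_softplus` handling; the name `selective_scan_with_init` and the `ssm_state` keyword are hypothetical, and the real change would of course also have to land in the CUDA backward pass.

```python
import torch

def selective_scan_with_init(u, delta, A, B, C, D=None, ssm_state=None):
    """Hypothetical sketch of selective_scan_fn with an initial state.

    u:         (batch, dim, seqlen)  input sequence
    delta:     (batch, dim, seqlen)  discretisation step
    A:         (dim, dstate)
    B, C:      (batch, dstate, seqlen)
    D:         (dim,) optional skip connection
    ssm_state: (batch, dim, dstate) optional initial hidden state

    Returns (out, last_state); last_state stays in the autograd
    graph, so gradients can flow back into an encoder.
    """
    batch, dim, seqlen = u.shape
    dstate = A.shape[1]
    x = (ssm_state if ssm_state is not None
         else torch.zeros(batch, dim, dstate, device=u.device, dtype=u.dtype))
    # Discretise: x_t = exp(delta_t * A) * x_{t-1} + (delta_t * B_t) * u_t
    deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
    deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
    ys = []
    for t in range(seqlen):
        x = deltaA[:, :, t] * x + deltaB_u[:, :, t]
        ys.append(torch.einsum('bdn,bn->bd', x, C[:, :, t]))  # y_t = C_t . x_t
    out = torch.stack(ys, dim=-1)  # (batch, dim, seqlen)
    if D is not None:
        out = out + u * D[:, None]
    return out, x
```

With that in place, the encoder-decoder pattern above would reduce to something like `_, state = selective_scan_with_init(enc_u, ...)` on the encoder side, followed by `out, _ = selective_scan_with_init(dec_u, ..., ssm_state=state)` on the decoder side.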
Related issues: #233, #101
PS: Excellent work! Very impressive (especially the CUDA part)!